Remove bad rw usages

This commit is contained in:
世界 2024-06-23 21:07:23 +08:00
parent d8ec9c46cc
commit e0196407a3
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
6 changed files with 74 additions and 62 deletions

View file

@ -1,6 +1,7 @@
package socks
import (
std_bufio "bufio"
"context"
"io"
"net"
@ -13,7 +14,7 @@ import (
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/varbin"
"github.com/sagernet/sing/protocol/socks/socks4"
"github.com/sagernet/sing/protocol/socks/socks5"
)
@ -32,7 +33,7 @@ func ClientHandshake4(conn io.ReadWriter, command byte, destination M.Socksaddr,
if err != nil {
return socks4.Response{}, err
}
response, err := socks4.ReadResponse(conn)
response, err := socks4.ReadResponse(varbin.StubReader(conn))
if err != nil {
return socks4.Response{}, err
}
@ -43,6 +44,7 @@ func ClientHandshake4(conn io.ReadWriter, command byte, destination M.Socksaddr,
}
func ClientHandshake5(conn io.ReadWriter, command byte, destination M.Socksaddr, username string, password string) (socks5.Response, error) {
reader := varbin.StubReader(conn)
var method byte
if username == "" {
method = socks5.AuthTypeNotRequired
@ -55,7 +57,7 @@ func ClientHandshake5(conn io.ReadWriter, command byte, destination M.Socksaddr,
if err != nil {
return socks5.Response{}, err
}
authResponse, err := socks5.ReadAuthResponse(conn)
authResponse, err := socks5.ReadAuthResponse(reader)
if err != nil {
return socks5.Response{}, err
}
@ -67,7 +69,7 @@ func ClientHandshake5(conn io.ReadWriter, command byte, destination M.Socksaddr,
if err != nil {
return socks5.Response{}, err
}
usernamePasswordResponse, err := socks5.ReadUsernamePasswordAuthResponse(conn)
usernamePasswordResponse, err := socks5.ReadUsernamePasswordAuthResponse(reader)
if err != nil {
return socks5.Response{}, err
}
@ -84,7 +86,7 @@ func ClientHandshake5(conn io.ReadWriter, command byte, destination M.Socksaddr,
if err != nil {
return socks5.Response{}, err
}
response, err := socks5.ReadResponse(conn)
response, err := socks5.ReadResponse(reader)
if err != nil {
return socks5.Response{}, err
}
@ -95,17 +97,17 @@ func ClientHandshake5(conn io.ReadWriter, command byte, destination M.Socksaddr,
}
func HandleConnection(ctx context.Context, conn net.Conn, authenticator *auth.Authenticator, handler Handler, metadata M.Metadata) error {
version, err := rw.ReadByte(conn)
return HandleConnection0(ctx, conn, std_bufio.NewReader(conn), authenticator, handler, metadata)
}
func HandleConnection0(ctx context.Context, conn net.Conn, reader *std_bufio.Reader, authenticator *auth.Authenticator, handler Handler, metadata M.Metadata) error {
version, err := reader.ReadByte()
if err != nil {
return err
}
return HandleConnection0(ctx, conn, version, authenticator, handler, metadata)
}
func HandleConnection0(ctx context.Context, conn net.Conn, version byte, authenticator *auth.Authenticator, handler Handler, metadata M.Metadata) error {
switch version {
case socks4.Version:
request, err := socks4.ReadRequest0(conn)
request, err := socks4.ReadRequest0(reader)
if err != nil {
return err
}
@ -142,7 +144,7 @@ func HandleConnection0(ctx context.Context, conn net.Conn, version byte, authent
return E.New("socks4: unsupported command ", request.Command)
}
case socks5.Version:
authRequest, err := socks5.ReadAuthRequest0(conn)
authRequest, err := socks5.ReadAuthRequest0(reader)
if err != nil {
return err
}
@ -167,7 +169,7 @@ func HandleConnection0(ctx context.Context, conn net.Conn, version byte, authent
return err
}
if authMethod == socks5.AuthTypeUsernamePassword {
usernamePasswordAuthRequest, err := socks5.ReadUsernamePasswordAuthRequest(conn)
usernamePasswordAuthRequest, err := socks5.ReadUsernamePasswordAuthRequest(reader)
if err != nil {
return err
}
@ -186,7 +188,7 @@ func HandleConnection0(ctx context.Context, conn net.Conn, version byte, authent
return E.New("socks5: authentication failed, username=", usernamePasswordAuthRequest.Username, ", password=", usernamePasswordAuthRequest.Password)
}
}
request, err := socks5.ReadRequest(conn)
request, err := socks5.ReadRequest(reader)
if err != nil {
return err
}

View file

@ -10,7 +10,7 @@ import (
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/varbin"
)
const (
@ -31,8 +31,8 @@ type Request struct {
Username string
}
func ReadRequest(reader io.Reader) (request Request, err error) {
version, err := rw.ReadByte(reader)
func ReadRequest(reader varbin.Reader) (request Request, err error) {
version, err := reader.ReadByte()
if err != nil {
return
}
@ -43,8 +43,8 @@ func ReadRequest(reader io.Reader) (request Request, err error) {
return ReadRequest0(reader)
}
func ReadRequest0(reader io.Reader) (request Request, err error) {
request.Command, err = rw.ReadByte(reader)
func ReadRequest0(reader varbin.Reader) (request Request, err error) {
request.Command, err = reader.ReadByte()
if err != nil {
return
}
@ -108,7 +108,7 @@ func WriteRequest(writer io.Writer, request Request) error {
common.Must1(buffer.WriteString(request.Destination.AddrString()))
common.Must(buffer.WriteZero())
}
return rw.WriteBytes(writer, buffer.Bytes())
return common.Error(writer.Write(buffer.Bytes()))
}
type Response struct {
@ -116,8 +116,8 @@ type Response struct {
Destination M.Socksaddr
}
func ReadResponse(reader io.Reader) (response Response, err error) {
version, err := rw.ReadByte(reader)
func ReadResponse(reader varbin.Reader) (response Response, err error) {
version, err := reader.ReadByte()
if err != nil {
return
}
@ -125,7 +125,7 @@ func ReadResponse(reader io.Reader) (response Response, err error) {
err = E.New("excepted socks4 response version 0, got ", version)
return
}
response.ReplyCode, err = rw.ReadByte(reader)
response.ReplyCode, err = reader.ReadByte()
if err != nil {
return
}
@ -151,13 +151,13 @@ func WriteResponse(writer io.Writer, response Response) error {
binary.Write(buffer, binary.BigEndian, response.Destination.Port),
common.Error(buffer.Write(response.Destination.Addr.AsSlice())),
)
return rw.WriteBytes(writer, buffer.Bytes())
return common.Error(writer.Write(buffer.Bytes()))
}
func readString(reader io.Reader) (string, error) {
func readString(reader varbin.Reader) (string, error) {
buffer := bytes.Buffer{}
for {
b, err := rw.ReadByte(reader)
b, err := reader.ReadByte()
if err != nil {
return "", err
}

View file

@ -8,7 +8,7 @@ import (
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/varbin"
)
const (
@ -55,11 +55,11 @@ func WriteAuthRequest(writer io.Writer, request AuthRequest) error {
buffer.WriteByte(byte(len(request.Methods))),
common.Error(buffer.Write(request.Methods)),
)
return rw.WriteBytes(writer, buffer.Bytes())
return common.Error(writer.Write(buffer.Bytes()))
}
func ReadAuthRequest(reader io.Reader) (request AuthRequest, err error) {
version, err := rw.ReadByte(reader)
func ReadAuthRequest(reader varbin.Reader) (request AuthRequest, err error) {
version, err := reader.ReadByte()
if err != nil {
return
}
@ -70,12 +70,13 @@ func ReadAuthRequest(reader io.Reader) (request AuthRequest, err error) {
return ReadAuthRequest0(reader)
}
func ReadAuthRequest0(reader io.Reader) (request AuthRequest, err error) {
methodLen, err := rw.ReadByte(reader)
func ReadAuthRequest0(reader varbin.Reader) (request AuthRequest, err error) {
methodLen, err := reader.ReadByte()
if err != nil {
return
}
request.Methods, err = rw.ReadBytes(reader, int(methodLen))
request.Methods = make([]byte, methodLen)
_, err = io.ReadFull(reader, request.Methods)
return
}
@ -90,11 +91,11 @@ type AuthResponse struct {
}
func WriteAuthResponse(writer io.Writer, response AuthResponse) error {
return rw.WriteBytes(writer, []byte{Version, response.Method})
return common.Error(writer.Write([]byte{Version, response.Method}))
}
func ReadAuthResponse(reader io.Reader) (response AuthResponse, err error) {
version, err := rw.ReadByte(reader)
func ReadAuthResponse(reader varbin.Reader) (response AuthResponse, err error) {
version, err := reader.ReadByte()
if err != nil {
return
}
@ -102,7 +103,7 @@ func ReadAuthResponse(reader io.Reader) (response AuthResponse, err error) {
err = E.New("expected socks version 5, got ", version)
return
}
response.Method, err = rw.ReadByte(reader)
response.Method, err = reader.ReadByte()
return
}
@ -125,11 +126,11 @@ func WriteUsernamePasswordAuthRequest(writer io.Writer, request UsernamePassword
M.WriteSocksString(buffer, request.Username),
M.WriteSocksString(buffer, request.Password),
)
return rw.WriteBytes(writer, buffer.Bytes())
return common.Error(writer.Write(buffer.Bytes()))
}
func ReadUsernamePasswordAuthRequest(reader io.Reader) (request UsernamePasswordAuthRequest, err error) {
version, err := rw.ReadByte(reader)
func ReadUsernamePasswordAuthRequest(reader varbin.Reader) (request UsernamePasswordAuthRequest, err error) {
version, err := reader.ReadByte()
if err != nil {
return
}
@ -159,11 +160,11 @@ type UsernamePasswordAuthResponse struct {
}
func WriteUsernamePasswordAuthResponse(writer io.Writer, response UsernamePasswordAuthResponse) error {
return rw.WriteBytes(writer, []byte{1, response.Status})
return common.Error(writer.Write([]byte{1, response.Status}))
}
func ReadUsernamePasswordAuthResponse(reader io.Reader) (response UsernamePasswordAuthResponse, err error) {
version, err := rw.ReadByte(reader)
func ReadUsernamePasswordAuthResponse(reader varbin.Reader) (response UsernamePasswordAuthResponse, err error) {
version, err := reader.ReadByte()
if err != nil {
return
}
@ -171,7 +172,7 @@ func ReadUsernamePasswordAuthResponse(reader io.Reader) (response UsernamePasswo
err = E.New("excepted password request version 1, got ", version)
return
}
response.Status, err = rw.ReadByte(reader)
response.Status, err = reader.ReadByte()
return
}
@ -198,11 +199,11 @@ func WriteRequest(writer io.Writer, request Request) error {
if err != nil {
return err
}
return rw.WriteBytes(writer, buffer.Bytes())
return common.Error(writer.Write(buffer.Bytes()))
}
func ReadRequest(reader io.Reader) (request Request, err error) {
version, err := rw.ReadByte(reader)
func ReadRequest(reader varbin.Reader) (request Request, err error) {
version, err := reader.ReadByte()
if err != nil {
return
}
@ -210,11 +211,11 @@ func ReadRequest(reader io.Reader) (request Request, err error) {
err = E.New("expected socks version 5, got ", version)
return
}
request.Command, err = rw.ReadByte(reader)
request.Command, err = reader.ReadByte()
if err != nil {
return
}
err = rw.Skip(reader)
_, err = reader.ReadByte()
if err != nil {
return
}
@ -252,11 +253,11 @@ func WriteResponse(writer io.Writer, response Response) error {
if err != nil {
return err
}
return rw.WriteBytes(writer, buffer.Bytes())
return common.Error(writer.Write(buffer.Bytes()))
}
func ReadResponse(reader io.Reader) (response Response, err error) {
version, err := rw.ReadByte(reader)
func ReadResponse(reader varbin.Reader) (response Response, err error) {
version, err := reader.ReadByte()
if err != nil {
return
}
@ -264,11 +265,11 @@ func ReadResponse(reader io.Reader) (response Response, err error) {
err = E.New("expected socks version 5, got ", version)
return
}
response.ReplyCode, err = rw.ReadByte(reader)
response.ReplyCode, err = reader.ReadByte()
if err != nil {
return
}
err = rw.Skip(reader)
_, err = reader.ReadByte()
if err != nil {
return
}