From 03c21c0a1205daedc7c4f701b44f5eba8319ddca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 7 Sep 2023 09:06:46 +0800 Subject: [PATCH] Fix WriteAddrPort usage --- common/uot/lazy.go | 14 +++++++++++--- common/uot/protocol.go | 19 ++++++++++++------- protocol/socks/packet.go | 5 ++++- protocol/socks/packet_vectorised.go | 5 ++++- protocol/socks/socks5/protocol.go | 10 ++++++++-- 5 files changed, 39 insertions(+), 14 deletions(-) diff --git a/common/uot/lazy.go b/common/uot/lazy.go index 1b1b54a..8853564 100644 --- a/common/uot/lazy.go +++ b/common/uot/lazy.go @@ -30,7 +30,11 @@ func NewLazyConn(conn net.Conn, request Request) *Conn { func (c *LazyClientConn) Write(p []byte) (n int, err error) { if !c.requestWritten { - request := EncodeRequest(c.request) + var request *buf.Buffer + request, err = EncodeRequest(c.request) + if err != nil { + return + } err = c.writer.WriteVectorised([]*buf.Buffer{request, buf.As(p)}) if err != nil { return @@ -43,8 +47,12 @@ func (c *LazyClientConn) Write(p []byte) (n int, err error) { func (c *LazyClientConn) WriteVectorised(buffers []*buf.Buffer) error { if !c.requestWritten { - request := EncodeRequest(c.request) - err := c.writer.WriteVectorised(append([]*buf.Buffer{request}, buffers...)) + request, err := EncodeRequest(c.request) + if err != nil { + return err + } + + err = c.writer.WriteVectorised(append([]*buf.Buffer{request}, buffers...)) c.requestWritten = true return err } diff --git a/common/uot/protocol.go b/common/uot/protocol.go index 6ac2d84..4f968c7 100644 --- a/common/uot/protocol.go +++ b/common/uot/protocol.go @@ -51,20 +51,25 @@ func ReadRequest(reader io.Reader) (*Request, error) { return &request, nil } -func EncodeRequest(request Request) *buf.Buffer { +func EncodeRequest(request Request) (*buf.Buffer, error) { var bufferLen int bufferLen += 1 // isConnect bufferLen += M.SocksaddrSerializer.AddrPortLen(request.Destination) buffer := buf.NewSize(bufferLen) - common.Must( - binary.Write(buffer, binary.BigEndian, request.IsConnect), - M.SocksaddrSerializer.WriteAddrPort(buffer, request.Destination), - ) - return buffer + common.Must(binary.Write(buffer, binary.BigEndian, request.IsConnect)) + err := M.SocksaddrSerializer.WriteAddrPort(buffer, request.Destination) + if err != nil { + buffer.Release() + return nil, err + } + return buffer, nil } func WriteRequest(writer io.Writer, request Request) error { - buffer := EncodeRequest(request) + buffer, err := EncodeRequest(request) + if err != nil { + return err + } defer buffer.Release() return common.Error(writer.Write(buffer.Bytes())) } diff --git a/protocol/socks/packet.go b/protocol/socks/packet.go index 555ee79..4df672c 100644 --- a/protocol/socks/packet.go +++ b/protocol/socks/packet.go @@ -104,7 +104,10 @@ func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Sock func (c *AssociatePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { header := buf.With(buffer.ExtendHeader(3 + M.SocksaddrSerializer.AddrPortLen(destination))) common.Must(header.WriteZeroN(3)) - common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination)) + err := M.SocksaddrSerializer.WriteAddrPort(header, destination) + if err != nil { + return err + } return common.Error(bufio.WritePacketBuffer(c.NetPacketConn, buffer, c.remoteAddr)) } diff --git a/protocol/socks/packet_vectorised.go b/protocol/socks/packet_vectorised.go index 12684cc..c73286b 100644 --- a/protocol/socks/packet_vectorised.go +++ b/protocol/socks/packet_vectorised.go @@ -43,7 +43,10 @@ func (v *VectorisedAssociatePacketConn) WriteVectorisedPacket(buffers []*buf.Buf header := buf.NewSize(3 + M.SocksaddrSerializer.AddrPortLen(destination)) defer header.Release() common.Must(header.WriteZeroN(3)) - common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination)) + err := M.SocksaddrSerializer.WriteAddrPort(header, destination) + if err != nil { + return err + } return v.VectorisedPacketWriter.WriteVectorisedPacket(append([]*buf.Buffer{header}, buffers...), destination) } diff --git a/protocol/socks/socks5/protocol.go b/protocol/socks/socks5/protocol.go index bce361b..67d9797 100644 --- a/protocol/socks/socks5/protocol.go +++ b/protocol/socks/socks5/protocol.go @@ -193,8 +193,11 @@ func WriteRequest(writer io.Writer, request Request) error { buffer.WriteByte(Version), buffer.WriteByte(request.Command), buffer.WriteZero(), - M.SocksaddrSerializer.WriteAddrPort(buffer, request.Destination), ) + err := M.SocksaddrSerializer.WriteAddrPort(buffer, request.Destination) + if err != nil { + return err + } return rw.WriteBytes(writer, buffer.Bytes()) } @@ -244,8 +247,11 @@ func WriteResponse(writer io.Writer, response Response) error { buffer.WriteByte(Version), buffer.WriteByte(response.ReplyCode), buffer.WriteZero(), - M.SocksaddrSerializer.WriteAddrPort(buffer, bind), ) + err := M.SocksaddrSerializer.WriteAddrPort(buffer, bind) + if err != nil { + return err + } return rw.WriteBytes(writer, buffer.Bytes()) }