diff --git a/client_conn.go b/client_conn.go index c5c7d14..694cb62 100644 --- a/client_conn.go +++ b/client_conn.go @@ -56,7 +56,10 @@ func (c *clientConn) Write(b []byte) (n int, err error) { } buffer := buf.NewSize(streamRequestLen(request) + len(b)) defer buffer.Release() - EncodeStreamRequest(request, buffer) + err = EncodeStreamRequest(request, buffer) + if err != nil { + return + } buffer.Write(b) _, err = c.Conn.Write(buffer.Bytes()) if err != nil { @@ -143,7 +146,10 @@ func (c *clientPacketConn) writeRequest(payload []byte) (n int, err error) { } buffer := buf.NewSize(rLen) defer buffer.Release() - EncodeStreamRequest(request, buffer) + err = EncodeStreamRequest(request, buffer) + if err != nil { + return + } if len(payload) > 0 { common.Must( binary.Write(buffer, binary.BigEndian, uint16(len(payload))), @@ -340,10 +346,16 @@ func (c *clientPacketAddrConn) writeRequest(payload []byte, destination M.Socksa } buffer := buf.NewSize(rLen) defer buffer.Release() - EncodeStreamRequest(request, buffer) + err = EncodeStreamRequest(request, buffer) + if err != nil { + return + } if len(payload) > 0 { + err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination) + if err != nil { + return + } common.Must( - M.SocksaddrSerializer.WriteAddrPort(buffer, destination), binary.Write(buffer, binary.BigEndian, uint16(len(payload))), common.Error(buffer.Write(payload)), ) @@ -411,10 +423,11 @@ func (c *clientPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Soc } bLen := buffer.Len() header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination) + 2)) - common.Must( - M.SocksaddrSerializer.WriteAddrPort(header, destination), - binary.Write(header, binary.BigEndian, uint16(bLen)), - ) + err := M.SocksaddrSerializer.WriteAddrPort(header, destination) + if err != nil { + return err + } + common.Must(binary.Write(header, binary.BigEndian, uint16(bLen))) return c.ExtendedConn.WriteBuffer(buffer) } diff --git a/protocol.go b/protocol.go index 106b39c..d93268b 100644 --- a/protocol.go +++ b/protocol.go @@ -143,7 +143,7 @@ func streamRequestLen(request StreamRequest) int { return rLen } -func EncodeStreamRequest(request StreamRequest, buffer *buf.Buffer) { +func EncodeStreamRequest(request StreamRequest, buffer *buf.Buffer) error { destination := request.Destination var flags uint16 if request.Network == N.NetworkUDP { @@ -155,10 +155,8 @@ func EncodeStreamRequest(request StreamRequest, buffer *buf.Buffer) { destination = Destination } } - common.Must( - binary.Write(buffer, binary.BigEndian, flags), - M.SocksaddrSerializer.WriteAddrPort(buffer, destination), - ) + common.Must(binary.Write(buffer, binary.BigEndian, flags)) + return M.SocksaddrSerializer.WriteAddrPort(buffer, destination) } type StreamResponse struct { diff --git a/server_conn.go b/server_conn.go index bac7add..dfe6f70 100644 --- a/server_conn.go +++ b/server_conn.go @@ -276,7 +276,10 @@ func (c *serverPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Soc func (c *serverPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { pLen := buffer.Len() common.Must(binary.Write(buf.With(buffer.ExtendHeader(2)), binary.BigEndian, uint16(pLen))) - common.Must(M.SocksaddrSerializer.WriteAddrPort(buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination))), destination)) + err := M.SocksaddrSerializer.WriteAddrPort(buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination))), destination) + if err != nil { + return err + } if !c.responseWritten { c.access.Lock() if c.responseWritten {