diff --git a/client_conn.go b/client_conn.go index 7aba975..2304a9c 100644 --- a/client_conn.go +++ b/client_conn.go @@ -209,6 +209,37 @@ func (c *clientPacketConn) FrontHeadroom() int { return 2 } +func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + if !c.responseRead { + err = c.readResponse() + if err != nil { + return + } + c.responseRead = true + } + var length uint16 + err = binary.Read(c.ExtendedConn, binary.BigEndian, &length) + if err != nil { + return + } + if cap(p) < int(length) { + return 0, nil, io.ErrShortBuffer + } + n, err = io.ReadFull(c.ExtendedConn, p[:length]) + return +} + +func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if !c.requestWritten { + return c.writeRequest(p) + } + err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(p))) + if err != nil { + return + } + return c.ExtendedConn.Write(p) +} + func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { err = c.ReadBuffer(buffer) return diff --git a/server_conn.go b/server_conn.go index 52e9db7..a37b64a 100644 --- a/server_conn.go +++ b/server_conn.go @@ -2,6 +2,8 @@ package mux import ( "encoding/binary" + "io" + "net" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" @@ -122,6 +124,33 @@ func (c *serverPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad return c.ExtendedConn.WriteBuffer(buffer) } +func (c *serverPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + var length uint16 + err = binary.Read(c.ExtendedConn, binary.BigEndian, &length) + if err != nil { + return + } + if cap(p) < int(length) { + return 0, nil, io.ErrShortBuffer + } + n, err = io.ReadFull(c.ExtendedConn, p[:length]) + return +} + +func (c *serverPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if !c.responseWritten { + _, err = c.ExtendedConn.Write([]byte{statusSuccess}) + if err != nil { + return + } + } + err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(p))) + if err != nil { + return + } + return c.ExtendedConn.Write(p) +} + func (c *serverPacketConn) NeedAdditionalReadDeadline() bool { return true } @@ -160,6 +189,46 @@ func (c *serverPacketAddrConn) HandshakeFailure(err error) error { return c.ExtendedConn.WriteBuffer(buffer) } +func (c *serverPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + destination, err := M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn) + if err != nil { + return + } + if destination.IsFqdn() { + addr = destination + } else { + addr = destination.UDPAddr() + } + var length uint16 + err = binary.Read(c.ExtendedConn, binary.BigEndian, &length) + if err != nil { + return + } + if cap(p) < int(length) { + return 0, nil, io.ErrShortBuffer + } + n, err = io.ReadFull(c.ExtendedConn, p[:length]) + return +} + +func (c *serverPacketAddrConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if !c.responseWritten { + _, err = c.ExtendedConn.Write([]byte{statusSuccess}) + if err != nil { + return + } + } + err = M.SocksaddrSerializer.WriteAddrPort(c.ExtendedConn, M.SocksaddrFromNet(addr)) + if err != nil { + return + } + err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(p))) + if err != nil { + return + } + return c.ExtendedConn.Write(p) +} + func (c *serverPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { destination, err = M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn) if err != nil {