fix: AssociatePacketConn.ReadFrom, AssociatePacketConn.ReadPacket, AssociatePacketReadWaiter.WaitReadPacket remoteAddr被conn remoteAddr覆盖问题

This commit is contained in:
pengjunwen 2024-05-10 19:53:20 +08:00
parent 4ab8cac5eb
commit de04c8b788
2 changed files with 6 additions and 6 deletions

View file

@ -49,20 +49,20 @@ func (c *AssociatePacketConn) RemoteAddr() net.Addr {
//warn:unsafe //warn:unsafe
func (c *AssociatePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { func (c *AssociatePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, addr, err = c.NetPacketConn.ReadFrom(p) n, _, err = c.NetPacketConn.ReadFrom(p)
if err != nil { if err != nil {
return return
} }
if n < 3 { if n < 3 {
return 0, nil, ErrInvalidPacket return 0, nil, ErrInvalidPacket
} }
c.remoteAddr = M.SocksaddrFromNet(addr)
reader := bytes.NewReader(p[3:n]) reader := bytes.NewReader(p[3:n])
destination, err := M.SocksaddrSerializer.ReadAddrPort(reader) destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
if err != nil { if err != nil {
return return
} }
addr = destination.UDPAddr() addr = destination.UDPAddr()
c.remoteAddr = M.SocksaddrFromNet(addr)
index := 3 + int(reader.Size()) - reader.Len() index := 3 + int(reader.Size()) - reader.Len()
n = copy(p, p[index:n]) n = copy(p, p[index:n])
return return
@ -95,19 +95,19 @@ func (c *AssociatePacketConn) Write(b []byte) (n int, err error) {
} }
func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
destination, err = c.NetPacketConn.ReadPacket(buffer) _, err = c.NetPacketConn.ReadPacket(buffer)
if err != nil { if err != nil {
return M.Socksaddr{}, err return M.Socksaddr{}, err
} }
if buffer.Len() < 3 { if buffer.Len() < 3 {
return M.Socksaddr{}, ErrInvalidPacket return M.Socksaddr{}, ErrInvalidPacket
} }
c.remoteAddr = destination
buffer.Advance(3) buffer.Advance(3)
destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer) destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer)
if err != nil { if err != nil {
return return
} }
c.remoteAddr = destination
return destination.Unwrap(), nil return destination.Unwrap(), nil
} }

View file

@ -29,7 +29,7 @@ func (w *AssociatePacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptio
} }
func (w *AssociatePacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { func (w *AssociatePacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
buffer, destination, err = w.readWaiter.WaitReadPacket() buffer, _, err = w.readWaiter.WaitReadPacket()
if err != nil { if err != nil {
return return
} }
@ -37,12 +37,12 @@ func (w *AssociatePacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destin
buffer.Release() buffer.Release()
return nil, M.Socksaddr{}, ErrInvalidPacket return nil, M.Socksaddr{}, ErrInvalidPacket
} }
w.conn.remoteAddr = destination
buffer.Advance(3) buffer.Advance(3)
destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer) destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer)
if err != nil { if err != nil {
buffer.Release() buffer.Release()
return nil, M.Socksaddr{}, err return nil, M.Socksaddr{}, err
} }
w.conn.remoteAddr = destination
return return
} }