diff --git a/hysteria/packet.go b/hysteria/packet.go index 6ba8663..a597c76 100644 --- a/hysteria/packet.go +++ b/hysteria/packet.go @@ -6,6 +6,7 @@ import ( "encoding/binary" "errors" "io" + "math" "net" "os" "sync" @@ -13,6 +14,7 @@ import ( "github.com/sagernet/quic-go" "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/atomic" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/cache" E "github.com/sagernet/sing/common/exceptions" @@ -123,7 +125,7 @@ type udpPacketConn struct { quicConn quic.Connection data chan *udpMessage udpMTU int - packetId uint16 + packetId atomic.Uint32 closeOnce sync.Once defragger *udpDefragger onDestroy func() @@ -182,11 +184,15 @@ func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) if buffer.Len() > 0xffff { return quic.ErrMessageTooLarge(0xffff) } - c.packetId++ + packetId := c.packetId.Add(1) + if packetId > math.MaxUint16 { + c.packetId.Store(0) + packetId = 0 + } message := allocMessage() *message = udpMessage{ sessionID: c.sessionID, - packetID: c.packetId, + packetID: uint16(packetId), fragmentTotal: 1, host: destination.AddrString(), port: destination.Port, @@ -218,12 +224,16 @@ func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { if len(p) > 0xffff { return 0, quic.ErrMessageTooLarge(0xffff) } - c.packetId++ + packetId := c.packetId.Add(1) + if packetId > math.MaxUint16 { + c.packetId.Store(0) + packetId = 0 + } message := allocMessage() destination := M.SocksaddrFromNet(addr) *message = udpMessage{ sessionID: c.sessionID, - packetID: uint16(c.packetId), + packetID: uint16(packetId), fragmentTotal: 1, host: destination.AddrString(), port: destination.Port, diff --git a/hysteria2/packet.go b/hysteria2/packet.go index 8c1522a..e3c1f1c 100644 --- a/hysteria2/packet.go +++ b/hysteria2/packet.go @@ -6,6 +6,7 @@ import ( "encoding/binary" "errors" "io" + "math" "net" "os" "sync" @@ -15,6 +16,7 @@ import ( "github.com/sagernet/quic-go/quicvarint" "github.com/sagernet/sing-quic/hysteria2/internal/protocol" "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/atomic" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/cache" M "github.com/sagernet/sing/common/metadata" @@ -119,7 +121,7 @@ type udpPacketConn struct { quicConn quic.Connection data chan *udpMessage udpMTU int - packetId uint16 + packetId atomic.Uint32 closeOnce sync.Once defragger *udpDefragger onDestroy func() @@ -178,11 +180,15 @@ func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) if buffer.Len() > 0xffff { return quic.ErrMessageTooLarge(0xffff) } - c.packetId++ + packetId := c.packetId.Add(1) + if packetId > math.MaxUint16 { + c.packetId.Store(0) + packetId = 0 + } message := allocMessage() *message = udpMessage{ sessionID: c.sessionID, - packetID: uint16(c.packetId), + packetID: uint16(packetId), fragmentTotal: 1, destination: destination.String(), data: buffer, @@ -213,11 +219,15 @@ func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { if len(p) > 0xffff { return 0, quic.ErrMessageTooLarge(0xffff) } - c.packetId++ + packetId := c.packetId.Add(1) + if packetId > math.MaxUint16 { + c.packetId.Store(0) + packetId = 0 + } message := allocMessage() *message = udpMessage{ sessionID: c.sessionID, - packetID: uint16(c.packetId), + packetID: uint16(packetId), fragmentTotal: 1, destination: addr.String(), data: buf.As(p), diff --git a/tuic/packet.go b/tuic/packet.go index 0d6dc81..888695f 100644 --- a/tuic/packet.go +++ b/tuic/packet.go @@ -6,6 +6,7 @@ import ( "encoding/binary" "errors" "io" + "math" "net" "os" "sync" @@ -13,6 +14,7 @@ import ( "github.com/sagernet/quic-go" "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/atomic" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/cache" E "github.com/sagernet/sing/common/exceptions" @@ -127,7 +129,7 @@ type udpPacketConn struct { udpStream bool udpMTU int udpMTUTime time.Time - packetId uint16 + packetId atomic.Uint32 closeOnce sync.Once isServer bool defragger *udpDefragger @@ -202,11 +204,15 @@ func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) if !destination.IsValid() { return E.New("invalid destination address") } - c.packetId++ + packetId := c.packetId.Add(1) + if packetId > math.MaxUint16 { + c.packetId.Store(0) + packetId = 0 + } message := allocMessage() *message = udpMessage{ sessionID: c.sessionID, - packetID: uint16(c.packetId), + packetID: uint16(packetId), fragmentTotal: 1, destination: destination, data: buffer, @@ -243,11 +249,15 @@ func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { if !destination.IsValid() { return 0, E.New("invalid destination address") } - c.packetId++ + packetId := c.packetId.Add(1) + if packetId > math.MaxUint16 { + c.packetId.Store(0) + packetId = 0 + } message := allocMessage() *message = udpMessage{ sessionID: c.sessionID, - packetID: uint16(c.packetId), + packetID: uint16(packetId), fragmentTotal: 1, destination: destination, data: buf.As(p),