Refine udp salt check

This commit is contained in:
世界 2022-06-27 20:08:18 +08:00
parent 4fe3099239
commit 6d5e7fb635
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
7 changed files with 90 additions and 72 deletions

2
go.mod
View file

@ -3,7 +3,7 @@ module github.com/sagernet/sing-shadowsocks
go 1.18
require (
github.com/sagernet/sing v0.0.0-20220619130320-8793fe5e067d
github.com/sagernet/sing v0.0.0-20220627092450-605697c1aec0
golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e
lukechampine.com/blake3 v1.1.7
)

4
go.sum
View file

@ -1,8 +1,8 @@
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.0.12 h1:p9dKCg8i4gmOxtv35DvrYoWqYzQrvEVdjQ762Y0OqZE=
github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c=
github.com/sagernet/sing v0.0.0-20220619130320-8793fe5e067d h1:zr8y4wmNIxv6Kkvgqysx8Piy82ATAThEj1jaEf23YQs=
github.com/sagernet/sing v0.0.0-20220619130320-8793fe5e067d/go.mod h1:I67R/q5f67xDExL2kL3RLIP7kGJBOPkYXkpRAykgC+E=
github.com/sagernet/sing v0.0.0-20220627092450-605697c1aec0 h1:WRc+FBhOM12FwVphxpRgPLcr9+9JmFLuDKIBtoSrvwk=
github.com/sagernet/sing v0.0.0-20220627092450-605697c1aec0/go.mod h1:I67R/q5f67xDExL2kL3RLIP7kGJBOPkYXkpRAykgC+E=
golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e h1:T8NU3HyQ8ClP4SEE+KbFlg6n0NhuTsN4MyznaarGsZM=
golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/sys v0.0.0-20220615213510-4f61da869c0c h1:aFV+BgZ4svzjfabn8ERpuB4JI4N6/rdy1iusx77G3oU=

View file

@ -19,7 +19,6 @@ import (
"github.com/sagernet/sing-shadowsocks"
"github.com/sagernet/sing-shadowsocks/shadowaead"
"github.com/sagernet/sing-shadowsocks/shadowaead_2022/wg_replay"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
@ -582,6 +581,16 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
return M.Socksaddr{}, err
}
if sessionId == c.session.remoteSessionId {
if !c.session.window.Check(packetId) {
return M.Socksaddr{}, ErrPacketIdNotUnique
}
} else if sessionId == c.session.lastRemoteSessionId {
if !c.session.lastWindow.Check(packetId) {
return M.Socksaddr{}, ErrPacketIdNotUnique
}
}
var remoteCipher cipher.AEAD
if packetHeader != nil {
if sessionId == c.session.remoteSessionId {
@ -624,13 +633,9 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
}
if sessionId == c.session.remoteSessionId {
if !c.session.filter.ValidateCounter(packetId) {
return M.Socksaddr{}, ErrPacketIdNotUnique
}
c.session.window.Add(packetId)
} else if sessionId == c.session.lastRemoteSessionId {
if !c.session.lastFilter.ValidateCounter(packetId) {
return M.Socksaddr{}, ErrPacketIdNotUnique
}
c.session.lastWindow.Add(packetId)
c.session.lastRemoteSeen = time.Now().Unix()
} else {
if c.session.remoteSessionId != 0 {
@ -638,15 +643,15 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
return M.Socksaddr{}, ErrTooManyServerSessions
} else {
c.session.lastRemoteSessionId = c.session.remoteSessionId
c.session.lastFilter = c.session.filter
c.session.lastWindow = c.session.window
c.session.lastRemoteSeen = time.Now().Unix()
c.session.lastRemoteCipher = c.session.remoteCipher
c.session.filter = wg_replay.Filter{}
c.session.window = SlidingWindow{}
}
}
c.session.remoteSessionId = sessionId
c.session.remoteCipher = remoteCipher
c.session.filter.ValidateCounter(packetId)
c.session.window.Add(packetId)
}
var clientSessionId uint64
@ -786,8 +791,8 @@ type udpSession struct {
cipher cipher.AEAD
remoteCipher cipher.AEAD
lastRemoteCipher cipher.AEAD
filter wg_replay.Filter
lastFilter wg_replay.Filter
window SlidingWindow
lastWindow SlidingWindow
rng io.Reader
}

View file

@ -18,7 +18,6 @@ import (
"github.com/sagernet/sing-shadowsocks"
"github.com/sagernet/sing-shadowsocks/shadowaead"
"github.com/sagernet/sing-shadowsocks/shadowaead_2022/wg_replay"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
@ -411,7 +410,7 @@ returnErr:
return err
process:
if !session.filter.ValidateCounter(packetId) {
if !session.window.Check(packetId) {
err = ErrPacketIdNotUnique
goto returnErr
}
@ -425,6 +424,8 @@ process:
buffer.Truncate(buffer.Len() - shadowaead.Overhead)
}
session.window.Add(packetId)
var headerType byte
headerType, err = buffer.ReadByte()
if err != nil {
@ -547,7 +548,7 @@ type serverUDPSession struct {
packetId uint64
cipher cipher.AEAD
remoteCipher cipher.AEAD
filter wg_replay.Filter
window SlidingWindow
rng io.Reader
}

View file

@ -314,7 +314,7 @@ returnErr:
return err
process:
if !session.filter.ValidateCounter(packetId) {
if !session.window.Check(packetId) {
err = ErrPacketIdNotUnique
goto returnErr
}
@ -328,6 +328,8 @@ process:
buffer.Truncate(buffer.Len() - shadowaead.Overhead)
}
session.window.Add(packetId)
var headerType byte
headerType, err = buffer.ReadByte()
if err != nil {

View file

@ -0,0 +1,63 @@
package shadowaead_2022
const (
swBlockBitLog = 6 // 1<<6 == 64 bits
swBlockBits = 1 << swBlockBitLog // must be power of 2
swRingBlocks = 1 << 7 // must be power of 2
swBlockMask = swRingBlocks - 1
swBitMask = swBlockBits - 1
swSize = (swRingBlocks - 1) * swBlockBits
)
// SlidingWindow maintains a sliding window of uint64 counters.
type SlidingWindow struct {
last uint64
ring [swRingBlocks]uint64
}
// Reset resets the filter to its initial state.
func (f *SlidingWindow) Reset() {
f.last = 0
f.ring[0] = 0
}
// Check checks whether counter can be accepted by the sliding window filter.
func (f *SlidingWindow) Check(counter uint64) bool {
switch {
case counter > f.last: // ahead of window
return true
case f.last-counter > swSize: // behind window
return false
}
// In window. Check bit.
blockIndex := counter >> swBlockBitLog & swBlockMask
bitIndex := counter & swBitMask
return f.ring[blockIndex]>>bitIndex&1 == 0
}
// Add adds counter to the sliding window without checking if the counter is valid.
// Call Check beforehand to make sure the counter is valid.
func (f *SlidingWindow) Add(counter uint64) {
blockIndex := counter >> swBlockBitLog
// Check if counter is ahead of window.
if counter > f.last {
lastBlockIndex := f.last >> swBlockBitLog
diff := int(blockIndex - lastBlockIndex)
if diff > swRingBlocks {
diff = swRingBlocks
}
for i := 0; i < diff; i++ {
lastBlockIndex = (lastBlockIndex + 1) & swBlockMask
f.ring[lastBlockIndex] = 0
}
f.last = counter
}
blockIndex &= swBlockMask
bitIndex := counter & swBitMask
f.ring[blockIndex] |= 1 << bitIndex
}

View file

@ -1,53 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
// Package wg_replay implements an efficient anti-replay algorithm as specified in RFC 6479.
package wg_replay
type block uint64
const (
blockBitLog = 6 // 1<<6 == 64 bits
blockBits = 1 << blockBitLog // must be power of 2
ringBlocks = 1 << 7 // must be power of 2
windowSize = (ringBlocks - 1) * blockBits
blockMask = ringBlocks - 1
bitMask = blockBits - 1
)
// A Filter rejects replayed messages by checking if message counter value is
// within a sliding window of previously received messages.
// The zero value for Filter is an empty filter ready to use.
// Filters are unsafe for concurrent use.
type Filter struct {
last uint64
ring [ringBlocks]block
}
// ValidateCounter checks if the counter should be accepted.
// Overlimit counters (>= limit) are always rejected.
func (f *Filter) ValidateCounter(counter uint64) bool {
indexBlock := counter >> blockBitLog
if counter > f.last { // move window forward
current := f.last >> blockBitLog
diff := indexBlock - current
if diff > ringBlocks {
diff = ringBlocks // cap diff to clear the whole ring
}
for i := current + 1; i <= current+diff; i++ {
f.ring[i&blockMask] = 0
}
f.last = counter
} else if f.last-counter > windowSize { // behind current window
return false
}
// check and set bit
indexBlock &= blockMask
indexBit := counter & bitMask
old := f.ring[indexBlock]
new := old | 1<<indexBit
f.ring[indexBlock] = new
return old != new
}