mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-04 12:27:37 +03:00
Fix shadowsocks
This commit is contained in:
parent
bc80c3357c
commit
ffb0ba34c3
2 changed files with 115 additions and 119 deletions
|
@ -168,9 +168,6 @@ func (c *clientConn) writeRequest(payload []byte) error {
|
|||
defer request.Release()
|
||||
|
||||
common.Must1(request.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
|
||||
if c.method.replayFilter != nil {
|
||||
c.method.replayFilter.Check(request.Bytes())
|
||||
}
|
||||
|
||||
var writer io.Writer = c.Conn
|
||||
writer = &buf.BufferedWriter{
|
||||
|
|
|
@ -39,11 +39,12 @@ const (
|
|||
)
|
||||
|
||||
var (
|
||||
ErrBadHeaderType = E.New("bad header type")
|
||||
ErrBadTimestamp = E.New("bad timestamp")
|
||||
ErrBadRequestSalt = E.New("bad request salt")
|
||||
ErrBadClientSessionId = E.New("bad client session id")
|
||||
ErrPacketIdNotUnique = E.New("packet id not unique")
|
||||
ErrBadHeaderType = E.New("bad header type")
|
||||
ErrBadTimestamp = E.New("bad timestamp")
|
||||
ErrBadRequestSalt = E.New("bad request salt")
|
||||
ErrBadClientSessionId = E.New("bad client session id")
|
||||
ErrPacketIdNotUnique = E.New("packet id not unique")
|
||||
ErrTooManyServerSessions = E.New("server session changed more than once during the last minute")
|
||||
)
|
||||
|
||||
var List = []string{
|
||||
|
@ -57,7 +58,7 @@ func New(method string, psk []byte, secureRNG io.Reader) (shadowsocks.Method, er
|
|||
name: method,
|
||||
key: psk,
|
||||
secureRNG: secureRNG,
|
||||
replayFilter: replay.NewCuckoo(30),
|
||||
replayFilter: replay.NewCuckoo(60),
|
||||
}
|
||||
|
||||
if len(psk) != KeySaltSize {
|
||||
|
@ -68,32 +69,26 @@ func New(method string, psk []byte, secureRNG io.Reader) (shadowsocks.Method, er
|
|||
case "2022-blake3-aes-128-gcm":
|
||||
m.keyLength = 16
|
||||
m.constructor = newAESGCM
|
||||
m.udpBlockConstructor = newAES
|
||||
m.udpBlockCipher = newAES(psk)
|
||||
case "2022-blake3-aes-256-gcm":
|
||||
m.keyLength = 32
|
||||
m.constructor = newAESGCM
|
||||
m.udpBlockConstructor = newAES
|
||||
m.udpBlockCipher = newAES(psk)
|
||||
case "2022-blake3-chacha20-poly1305":
|
||||
m.keyLength = 32
|
||||
m.constructor = func(key []byte) cipher.AEAD {
|
||||
cipher, err := chacha20poly1305.New(key)
|
||||
common.Must(err)
|
||||
return cipher
|
||||
}
|
||||
m.udpConstructor = func(key []byte) cipher.AEAD {
|
||||
cipher, err := chacha20poly1305.NewX(key)
|
||||
common.Must(err)
|
||||
return cipher
|
||||
}
|
||||
m.constructor = newChacha20Poly1305
|
||||
m.udpCipher = newXChacha20Poly1305(psk)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func Blake3DeriveKey(secret, salt, outKey []byte) {
|
||||
func Blake3DeriveKey(secret, salt []byte, keyLength int) []byte {
|
||||
sessionKey := make([]byte, len(secret)+len(salt))
|
||||
copy(sessionKey, secret)
|
||||
copy(sessionKey[len(secret):], salt)
|
||||
outKey := make([]byte, keyLength)
|
||||
blake3.DeriveKey(outKey, "shadowsocks 2022 session subkey", sessionKey)
|
||||
return outKey
|
||||
}
|
||||
|
||||
func newAES(key []byte) cipher.Block {
|
||||
|
@ -110,15 +105,27 @@ func newAESGCM(key []byte) cipher.AEAD {
|
|||
return aead
|
||||
}
|
||||
|
||||
func newChacha20Poly1305(key []byte) cipher.AEAD {
|
||||
cipher, err := chacha20poly1305.New(key)
|
||||
common.Must(err)
|
||||
return cipher
|
||||
}
|
||||
|
||||
func newXChacha20Poly1305(key []byte) cipher.AEAD {
|
||||
cipher, err := chacha20poly1305.NewX(key)
|
||||
common.Must(err)
|
||||
return cipher
|
||||
}
|
||||
|
||||
type Method struct {
|
||||
name string
|
||||
keyLength int
|
||||
constructor func(key []byte) cipher.AEAD
|
||||
udpBlockConstructor func(key []byte) cipher.Block
|
||||
udpConstructor func(key []byte) cipher.AEAD
|
||||
key []byte
|
||||
secureRNG io.Reader
|
||||
replayFilter replay.Filter
|
||||
name string
|
||||
keyLength int
|
||||
constructor func(key []byte) cipher.AEAD
|
||||
udpCipher cipher.AEAD
|
||||
udpBlockCipher cipher.Block
|
||||
key []byte
|
||||
secureRNG io.Reader
|
||||
replayFilter replay.Filter
|
||||
}
|
||||
|
||||
func (m *Method) Name() string {
|
||||
|
@ -147,50 +154,7 @@ func (m *Method) DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn
|
|||
}
|
||||
|
||||
func (m *Method) DialPacketConn(conn net.Conn) socks.PacketConn {
|
||||
return &clientPacketConn{conn, m, newUDPSession()}
|
||||
}
|
||||
|
||||
func (m *Method) EncodePacket(buffer *buf.Buffer) error {
|
||||
if m.udpConstructor == nil {
|
||||
// aes
|
||||
packetHeader := buffer.To(aes.BlockSize)
|
||||
subKey := make([]byte, m.keyLength)
|
||||
Blake3DeriveKey(m.key, packetHeader[:8], subKey)
|
||||
|
||||
cipher := m.constructor(subKey)
|
||||
cipher.Seal(buffer.Index(aes.BlockSize), packetHeader[4:16], buffer.From(aes.BlockSize), nil)
|
||||
buffer.Extend(cipher.Overhead())
|
||||
m.udpBlockConstructor(m.key).Encrypt(packetHeader, packetHeader)
|
||||
} else {
|
||||
// xchacha
|
||||
cipher := m.udpConstructor(m.key)
|
||||
cipher.Seal(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil)
|
||||
buffer.Extend(cipher.Overhead())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Method) DecodePacket(buffer *buf.Buffer) error {
|
||||
if m.udpBlockConstructor != nil {
|
||||
if buffer.Len() <= aes.BlockSize {
|
||||
return E.New("insufficient data: ", buffer.Len())
|
||||
}
|
||||
packetHeader := buffer.To(aes.BlockSize)
|
||||
m.udpBlockConstructor(m.key).Decrypt(packetHeader, packetHeader)
|
||||
subKey := make([]byte, m.keyLength)
|
||||
Blake3DeriveKey(m.key, packetHeader[:8], subKey)
|
||||
_, err := m.constructor(subKey).Open(buffer.Index(aes.BlockSize), packetHeader[4:16], buffer.From(aes.BlockSize), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
_, err := m.udpConstructor(m.key).Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buffer.Advance(PacketNonceSize)
|
||||
}
|
||||
return nil
|
||||
return &clientPacketConn{conn, m, m.newUDPSession()}
|
||||
}
|
||||
|
||||
type clientConn struct {
|
||||
|
@ -214,12 +178,8 @@ func (c *clientConn) writeRequest(payload []byte) error {
|
|||
|
||||
salt := make([]byte, KeySaltSize)
|
||||
common.Must1(io.ReadFull(c.method.secureRNG, salt))
|
||||
c.method.replayFilter.Check(salt)
|
||||
common.Must1(request.Write(salt))
|
||||
|
||||
subKey := make([]byte, c.method.keyLength)
|
||||
Blake3DeriveKey(c.method.key, salt, subKey)
|
||||
|
||||
var writer io.Writer = c.Conn
|
||||
writer = &buf.BufferedWriter{
|
||||
Writer: writer,
|
||||
|
@ -227,7 +187,7 @@ func (c *clientConn) writeRequest(payload []byte) error {
|
|||
}
|
||||
writer = shadowaead.NewWriter(
|
||||
writer,
|
||||
c.method.constructor(subKey),
|
||||
c.method.constructor(Blake3DeriveKey(c.method.key, salt, c.method.keyLength)),
|
||||
MaxPacketSize,
|
||||
)
|
||||
|
||||
|
@ -257,7 +217,7 @@ func (c *clientConn) writeRequest(payload []byte) error {
|
|||
return E.Cause(err, "write payload")
|
||||
}
|
||||
} else {
|
||||
pLen := rand.Intn(MaxPaddingLength)
|
||||
pLen := rand.Intn(MaxPaddingLength + 1)
|
||||
err = binary.Write(writer, binary.BigEndian, uint16(pLen))
|
||||
if err != nil {
|
||||
return E.Cause(err, "write padding length")
|
||||
|
@ -299,12 +259,9 @@ func (c *clientConn) readResponse() error {
|
|||
return E.New("salt is not unique")
|
||||
}
|
||||
|
||||
subKey := make([]byte, c.method.keyLength)
|
||||
Blake3DeriveKey(c.method.key, salt, subKey)
|
||||
|
||||
reader := shadowaead.NewReader(
|
||||
c.Conn,
|
||||
c.method.constructor(subKey),
|
||||
c.method.constructor(Blake3DeriveKey(c.method.key, salt, c.method.keyLength)),
|
||||
MaxPacketSize,
|
||||
)
|
||||
|
||||
|
@ -391,7 +348,7 @@ type clientPacketConn struct {
|
|||
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
|
||||
defer buffer.Release()
|
||||
header := buf.New()
|
||||
if c.method.udpConstructor != nil {
|
||||
if c.method.udpCipher != nil {
|
||||
common.Must1(header.ReadFullFrom(c.method.secureRNG, PacketNonceSize))
|
||||
}
|
||||
common.Must(
|
||||
|
@ -401,15 +358,19 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPo
|
|||
binary.Write(header, binary.BigEndian, uint64(time.Now().Unix())),
|
||||
binary.Write(header, binary.BigEndian, uint16(0)), // padding length
|
||||
)
|
||||
c.session.filter.ValidateCounter(c.session.packetId, math.MaxUint64)
|
||||
err := socks.AddressSerializer.WriteAddrPort(header, destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buffer = buffer.WriteBufferAtFirst(header)
|
||||
err = c.method.EncodePacket(buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
if c.method.udpCipher != nil {
|
||||
c.method.udpCipher.Seal(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil)
|
||||
buffer.Extend(c.method.udpCipher.Overhead())
|
||||
} else {
|
||||
packetHeader := buffer.To(aes.BlockSize)
|
||||
c.session.cipher.Seal(buffer.Index(aes.BlockSize), packetHeader[4:16], buffer.From(aes.BlockSize), nil)
|
||||
buffer.Extend(c.session.cipher.Overhead())
|
||||
c.method.udpBlockCipher.Encrypt(packetHeader, packetHeader)
|
||||
}
|
||||
return common.Error(c.Write(buffer.Bytes()))
|
||||
}
|
||||
|
@ -421,51 +382,49 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
|||
}
|
||||
buffer.Truncate(n)
|
||||
|
||||
err = c.method.DecodePacket(buffer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
var packetHeader []byte
|
||||
if c.method.udpCipher != nil {
|
||||
_, err = c.method.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decrypt packet")
|
||||
}
|
||||
buffer.Advance(PacketNonceSize)
|
||||
} else {
|
||||
packetHeader = buffer.To(aes.BlockSize)
|
||||
c.method.udpBlockCipher.Decrypt(packetHeader, packetHeader)
|
||||
}
|
||||
|
||||
var sessionId uint64
|
||||
var sessionId, packetId uint64
|
||||
err = binary.Read(buffer, binary.BigEndian, &sessionId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var isLastSessionId bool
|
||||
if c.session.remoteSessionId == 0 {
|
||||
c.session.remoteSessionId = sessionId
|
||||
} else if sessionId != c.session.remoteSessionId {
|
||||
if sessionId == c.session.lastRemoteSessionId {
|
||||
isLastSessionId = true
|
||||
} else {
|
||||
c.session.lastRemoteSessionId = c.session.remoteSessionId
|
||||
c.session.remoteSessionId = sessionId
|
||||
c.session.lastFilter = c.session.filter
|
||||
c.session.filter = new(wgReplay.Filter)
|
||||
}
|
||||
}
|
||||
|
||||
var packetId uint64
|
||||
err = binary.Read(buffer, binary.BigEndian, &packetId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !isLastSessionId {
|
||||
if !c.session.filter.ValidateCounter(packetId, math.MaxUint64) {
|
||||
return nil, ErrPacketIdNotUnique
|
||||
|
||||
var remoteCipher cipher.AEAD
|
||||
if packetHeader != nil {
|
||||
if sessionId == c.session.remoteSessionId {
|
||||
remoteCipher = c.session.remoteCipher
|
||||
} else if sessionId == c.session.lastRemoteSessionId {
|
||||
remoteCipher = c.session.lastRemoteCipher
|
||||
} else {
|
||||
remoteCipher = c.method.constructor(Blake3DeriveKey(c.method.key, packetHeader[:8], c.method.keyLength))
|
||||
}
|
||||
} else {
|
||||
if !c.session.lastFilter.ValidateCounter(packetId, math.MaxUint64) {
|
||||
return nil, ErrPacketIdNotUnique
|
||||
_, err = remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decrypt packet")
|
||||
}
|
||||
}
|
||||
|
||||
headerType, err := buffer.ReadBytes(1)
|
||||
var headerType byte
|
||||
headerType, err = buffer.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if headerType[0] != HeaderTypeServer {
|
||||
if headerType != HeaderTypeServer {
|
||||
return nil, ErrBadHeaderType
|
||||
}
|
||||
|
||||
|
@ -478,6 +437,35 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
|||
return nil, ErrBadTimestamp
|
||||
}
|
||||
|
||||
if sessionId == c.session.remoteSessionId {
|
||||
if !c.session.filter.ValidateCounter(packetId, math.MaxUint64) {
|
||||
return nil, ErrPacketIdNotUnique
|
||||
}
|
||||
c.session.remoteSeen = time.Now().Unix()
|
||||
} else if sessionId == c.session.lastRemoteSessionId {
|
||||
if !c.session.lastFilter.ValidateCounter(packetId, math.MaxUint64) {
|
||||
return nil, ErrPacketIdNotUnique
|
||||
}
|
||||
remoteCipher = c.session.lastRemoteCipher
|
||||
c.session.lastRemoteSeen = time.Now().Unix()
|
||||
} else {
|
||||
if c.session.remoteSessionId != 0 {
|
||||
if time.Now().Unix()-c.session.lastRemoteSeen < 60 {
|
||||
return nil, ErrTooManyServerSessions
|
||||
} else {
|
||||
c.session.lastRemoteSessionId = c.session.remoteSessionId
|
||||
c.session.lastFilter = c.session.filter
|
||||
c.session.lastRemoteSeen = c.session.remoteSeen
|
||||
c.session.lastRemoteCipher = c.session.remoteCipher
|
||||
c.session.filter = new(wgReplay.Filter)
|
||||
}
|
||||
}
|
||||
c.session.remoteSessionId = sessionId
|
||||
c.session.remoteCipher = remoteCipher
|
||||
c.session.remoteSeen = time.Now().Unix()
|
||||
c.session.filter.ValidateCounter(packetId, math.MaxUint64)
|
||||
}
|
||||
|
||||
var clientSessionId uint64
|
||||
err = binary.Read(buffer, binary.BigEndian, &clientSessionId)
|
||||
if err != nil {
|
||||
|
@ -503,6 +491,11 @@ type udpSession struct {
|
|||
packetId uint64
|
||||
remoteSessionId uint64
|
||||
lastRemoteSessionId uint64
|
||||
remoteSeen int64
|
||||
lastRemoteSeen int64
|
||||
cipher cipher.AEAD
|
||||
remoteCipher cipher.AEAD
|
||||
lastRemoteCipher cipher.AEAD
|
||||
filter *wgReplay.Filter
|
||||
lastFilter *wgReplay.Filter
|
||||
}
|
||||
|
@ -511,9 +504,15 @@ func (s *udpSession) nextPacketId() uint64 {
|
|||
return atomic.AddUint64(&s.packetId, 1)
|
||||
}
|
||||
|
||||
func newUDPSession() *udpSession {
|
||||
return &udpSession{
|
||||
func (m *Method) newUDPSession() *udpSession {
|
||||
session := &udpSession{
|
||||
sessionId: rand.Uint64(),
|
||||
filter: new(wgReplay.Filter),
|
||||
}
|
||||
if m.udpCipher == nil {
|
||||
sessionId := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(sessionId, session.sessionId)
|
||||
session.cipher = m.constructor(Blake3DeriveKey(m.key, sessionId, m.keyLength))
|
||||
}
|
||||
return session
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue