mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-04 20:37:40 +03:00
Export std compatible shadowsocks packet conn
This commit is contained in:
parent
f1744155fc
commit
f720169354
6 changed files with 179 additions and 147 deletions
|
@ -32,6 +32,11 @@ type PacketConn interface {
|
||||||
SetWriteDeadline(t time.Time) error
|
SetWriteDeadline(t time.Time) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type NetPacketConn interface {
|
||||||
|
net.PacketConn
|
||||||
|
PacketConn
|
||||||
|
}
|
||||||
|
|
||||||
type UDPHandler interface {
|
type UDPHandler interface {
|
||||||
NewPacket(conn PacketConn, buffer *buf.Buffer, metadata M.Metadata) error
|
NewPacket(conn PacketConn, buffer *buf.Buffer, metadata M.Metadata) error
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"github.com/sagernet/sing/common/buf"
|
"github.com/sagernet/sing/common/buf"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
"github.com/sagernet/sing/common/rw"
|
||||||
"github.com/sagernet/sing/common/udpnat"
|
"github.com/sagernet/sing/common/udpnat"
|
||||||
"github.com/sagernet/sing/protocol/socks5"
|
"github.com/sagernet/sing/protocol/socks5"
|
||||||
)
|
)
|
||||||
|
@ -48,7 +49,7 @@ func (m *NoneMethod) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.C
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *NoneMethod) DialPacketConn(conn net.Conn) N.PacketConn {
|
func (m *NoneMethod) DialPacketConn(conn net.Conn) N.NetPacketConn {
|
||||||
return &nonePacketConn{conn}
|
return &nonePacketConn{conn}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -117,14 +118,13 @@ direct:
|
||||||
|
|
||||||
func (c *noneConn) ReadFrom(r io.Reader) (n int64, err error) {
|
func (c *noneConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||||
if !c.handshake {
|
if !c.handshake {
|
||||||
panic("missing client handshake")
|
return rw.ReadFrom0(c, r)
|
||||||
}
|
}
|
||||||
return c.Conn.(io.ReaderFrom).ReadFrom(r)
|
return c.Conn.(io.ReaderFrom).ReadFrom(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *noneConn) WriteTo(w io.Writer) (n int64, err error) {
|
func (c *noneConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||||
return io.Copy(w, c.Conn)
|
return io.Copy(w, c.Conn)
|
||||||
// return c.Conn.(io.WriterTo).WriteTo(w)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *noneConn) RemoteAddr() net.Addr {
|
func (c *noneConn) RemoteAddr() net.Addr {
|
||||||
|
@ -144,26 +144,43 @@ func (c *nonePacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *nonePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
func (c *nonePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||||
headerLen := socks5.AddressSerializer.AddrPortLen(destination)
|
header := buf.With(buffer.ExtendHeader(socks5.AddressSerializer.AddrPortLen(destination)))
|
||||||
var header *buf.Buffer
|
|
||||||
var writeHeader bool
|
|
||||||
if buffer.Start() >= headerLen {
|
|
||||||
header = buf.With(buffer.ExtendHeader(headerLen))
|
|
||||||
} else {
|
|
||||||
_buffer := buf.StackNewSize(buffer.Len() + headerLen)
|
|
||||||
defer runtime.KeepAlive(_buffer)
|
|
||||||
header = common.Dup(_buffer)
|
|
||||||
writeHeader = true
|
|
||||||
}
|
|
||||||
err := socks5.AddressSerializer.WriteAddrPort(header, destination)
|
err := socks5.AddressSerializer.WriteAddrPort(header, destination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if writeHeader {
|
return common.Error(buffer.WriteTo(c))
|
||||||
return common.Error(header.WriteTo(c))
|
}
|
||||||
} else {
|
|
||||||
return common.Error(buffer.WriteTo(c))
|
func (c *nonePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||||
|
n, err = c.Read(p)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
buffer := buf.With(p[:n])
|
||||||
|
destination, err := socks5.AddressSerializer.ReadAddrPort(buffer)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
addr = destination.UDPAddr()
|
||||||
|
n = copy(p, buffer.Bytes())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *nonePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||||
|
destination := M.SocksaddrFromNet(addr)
|
||||||
|
_buffer := buf.Make(socks5.AddressSerializer.AddrPortLen(destination) + len(p))
|
||||||
|
defer runtime.KeepAlive(_buffer)
|
||||||
|
buffer := buf.With(common.Dup(_buffer))
|
||||||
|
err = socks5.AddressSerializer.WriteAddrPort(buffer, destination)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, err = buffer.Write(p)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return len(p), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type NoneService struct {
|
type NoneService struct {
|
||||||
|
|
|
@ -16,7 +16,7 @@ type Method interface {
|
||||||
KeyLength() int
|
KeyLength() int
|
||||||
DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error)
|
DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error)
|
||||||
DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn
|
DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn
|
||||||
DialPacketConn(conn net.Conn) N.PacketConn
|
DialPacketConn(conn net.Conn) N.NetPacketConn
|
||||||
}
|
}
|
||||||
|
|
||||||
func Key(password []byte, keySize int) []byte {
|
func Key(password []byte, keySize int) []byte {
|
||||||
|
|
|
@ -162,7 +162,7 @@ func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Method) DialPacketConn(conn net.Conn) N.PacketConn {
|
func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn {
|
||||||
return &clientPacketConn{m, conn}
|
return &clientPacketConn{m, conn}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -365,6 +365,48 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
|
||||||
return socks5.AddressSerializer.ReadAddrPort(buffer)
|
return socks5.AddressSerializer.ReadAddrPort(buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||||
|
n, err = c.Read(p)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b := buf.With(p[:n])
|
||||||
|
err = c.DecodePacket(b)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
destination, err := socks5.AddressSerializer.ReadAddrPort(b)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
addr = destination.UDPAddr()
|
||||||
|
n = copy(p, b.Bytes())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||||
|
_buffer := buf.StackNew()
|
||||||
|
defer runtime.KeepAlive(_buffer)
|
||||||
|
buffer := common.Dup(_buffer)
|
||||||
|
err = socks5.AddressSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, err = buffer.Write(p)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = c.EncodePacket(buffer)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, err = c.Write(buffer.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *clientPacketConn) UpstreamReader() io.Reader {
|
func (c *clientPacketConn) UpstreamReader() io.Reader {
|
||||||
return c.Conn
|
return c.Conn
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,9 +16,7 @@ import (
|
||||||
|
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
"github.com/sagernet/sing/common/buf"
|
"github.com/sagernet/sing/common/buf"
|
||||||
"github.com/sagernet/sing/common/debug"
|
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
"github.com/sagernet/sing/common/log"
|
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
"github.com/sagernet/sing/common/replay"
|
"github.com/sagernet/sing/common/replay"
|
||||||
|
@ -51,8 +49,6 @@ const (
|
||||||
overhead = 16
|
overhead = 16
|
||||||
)
|
)
|
||||||
|
|
||||||
var logger = log.NewLogger("shadowaead (2022)")
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrBadHeaderType = E.New("bad header type")
|
ErrBadHeaderType = E.New("bad header type")
|
||||||
ErrBadTimestamp = E.New("bad timestamp")
|
ErrBadTimestamp = E.New("bad timestamp")
|
||||||
|
@ -182,7 +178,7 @@ func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Method) DialPacketConn(conn net.Conn) N.PacketConn {
|
func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn {
|
||||||
return &clientPacketConn{conn, m, m.newUDPSession()}
|
return &clientPacketConn{conn, m, m.newUDPSession()}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -206,14 +202,7 @@ func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte)
|
||||||
if pskLen < 2 {
|
if pskLen < 2 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if debug.Enabled {
|
|
||||||
logger.Trace("begin extended identity headers")
|
|
||||||
}
|
|
||||||
for i, psk := range m.pskList {
|
for i, psk := range m.pskList {
|
||||||
if debug.Enabled {
|
|
||||||
logger.Trace("current psk ", buf.EncodeHexString(psk[:]))
|
|
||||||
}
|
|
||||||
|
|
||||||
keyMaterial := make([]byte, 2*KeySaltSize)
|
keyMaterial := make([]byte, 2*KeySaltSize)
|
||||||
copy(keyMaterial, psk[:])
|
copy(keyMaterial, psk[:])
|
||||||
copy(keyMaterial[KeySaltSize:], salt)
|
copy(keyMaterial[KeySaltSize:], salt)
|
||||||
|
@ -221,35 +210,18 @@ func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte)
|
||||||
identitySubkey := common.Dup(_identitySubkey)
|
identitySubkey := common.Dup(_identitySubkey)
|
||||||
blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial)
|
blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial)
|
||||||
|
|
||||||
if debug.Enabled {
|
|
||||||
logger.Trace("identity subkey ", buf.EncodeHexString(identitySubkey))
|
|
||||||
}
|
|
||||||
|
|
||||||
pskHash := m.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
|
pskHash := m.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
|
||||||
if debug.Enabled {
|
|
||||||
logger.Trace("next psk hash ", buf.EncodeHexString(pskHash))
|
|
||||||
}
|
|
||||||
|
|
||||||
header := request.Extend(16)
|
header := request.Extend(16)
|
||||||
m.blockConstructor(identitySubkey).Encrypt(header, pskHash)
|
m.blockConstructor(identitySubkey).Encrypt(header, pskHash)
|
||||||
runtime.KeepAlive(_identitySubkey)
|
runtime.KeepAlive(_identitySubkey)
|
||||||
if debug.Enabled {
|
|
||||||
logger.Trace("encoded ", buf.EncodeHexString(header))
|
|
||||||
}
|
|
||||||
if i == pskLen-2 {
|
if i == pskLen-2 {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if debug.Enabled {
|
|
||||||
logger.Trace("ended extended identity headers")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientConn) writeRequest(payload []byte) error {
|
func (c *clientConn) writeRequest(payload []byte) error {
|
||||||
if debug.Enabled {
|
|
||||||
logger.Trace("begin client handshake request")
|
|
||||||
}
|
|
||||||
|
|
||||||
salt := make([]byte, KeySaltSize)
|
salt := make([]byte, KeySaltSize)
|
||||||
common.Must1(io.ReadFull(c.method.secureRNG, salt))
|
common.Must1(io.ReadFull(c.method.secureRNG, salt))
|
||||||
|
|
||||||
|
@ -276,10 +248,6 @@ func (c *clientConn) writeRequest(payload []byte) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(payload) > 0 {
|
if len(payload) > 0 {
|
||||||
if debug.Enabled {
|
|
||||||
logger.Trace("write payload ", len(payload), "B")
|
|
||||||
}
|
|
||||||
|
|
||||||
err = binary.Write(bufferedWriter, binary.BigEndian, uint16(0))
|
err = binary.Write(bufferedWriter, binary.BigEndian, uint16(0))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "write padding length")
|
return E.Cause(err, "write padding length")
|
||||||
|
@ -290,9 +258,6 @@ func (c *clientConn) writeRequest(payload []byte) error {
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
pLen := rand.Intn(MaxPaddingLength + 1)
|
pLen := rand.Intn(MaxPaddingLength + 1)
|
||||||
if debug.Enabled {
|
|
||||||
logger.Trace("write padding ", pLen, "B")
|
|
||||||
}
|
|
||||||
err = binary.Write(bufferedWriter, binary.BigEndian, uint16(pLen))
|
err = binary.Write(bufferedWriter, binary.BigEndian, uint16(pLen))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "write padding length")
|
return E.Cause(err, "write padding length")
|
||||||
|
@ -308,10 +273,6 @@ func (c *clientConn) writeRequest(payload []byte) error {
|
||||||
return E.Cause(err, "client handshake")
|
return E.Cause(err, "client handshake")
|
||||||
}
|
}
|
||||||
|
|
||||||
if debug.Enabled {
|
|
||||||
logger.Trace("ended client handshake request")
|
|
||||||
}
|
|
||||||
|
|
||||||
c.requestSalt = salt
|
c.requestSalt = salt
|
||||||
c.writer = writer
|
c.writer = writer
|
||||||
return nil
|
return nil
|
||||||
|
@ -329,10 +290,6 @@ func (c *clientConn) readResponse() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if debug.Enabled {
|
|
||||||
logger.Trace("begin client handshake response")
|
|
||||||
}
|
|
||||||
|
|
||||||
_salt := make([]byte, KeySaltSize)
|
_salt := make([]byte, KeySaltSize)
|
||||||
salt := common.Dup(_salt)
|
salt := common.Dup(_salt)
|
||||||
_, err := io.ReadFull(c.Conn, salt)
|
_, err := io.ReadFull(c.Conn, salt)
|
||||||
|
@ -341,7 +298,6 @@ func (c *clientConn) readResponse() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if !c.method.replayFilter.Check(salt) {
|
if !c.method.replayFilter.Check(salt) {
|
||||||
logger.Warn("read salt ", buf.EncodeHexString(salt))
|
|
||||||
return E.New("salt not unique")
|
return E.New("salt not unique")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -370,9 +326,6 @@ func (c *clientConn) readResponse() error {
|
||||||
|
|
||||||
diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch))))
|
diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch))))
|
||||||
if diff > 30 {
|
if diff > 30 {
|
||||||
if debug.Enabled {
|
|
||||||
logger.Trace("server timestamp ", time.Unix(int64(epoch), 0).String(), " diff ", diff)
|
|
||||||
}
|
|
||||||
return ErrBadTimestamp
|
return ErrBadTimestamp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -384,9 +337,6 @@ func (c *clientConn) readResponse() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if bytes.Compare(requestSalt, c.requestSalt) > 0 {
|
if bytes.Compare(requestSalt, c.requestSalt) > 0 {
|
||||||
if debug.Enabled {
|
|
||||||
logger.Trace("request salt ", buf.EncodeHexString(requestSalt))
|
|
||||||
}
|
|
||||||
return ErrBadRequestSalt
|
return ErrBadRequestSalt
|
||||||
}
|
}
|
||||||
runtime.KeepAlive(_requestSalt)
|
runtime.KeepAlive(_requestSalt)
|
||||||
|
@ -394,9 +344,6 @@ func (c *clientConn) readResponse() error {
|
||||||
c.requestSalt = nil
|
c.requestSalt = nil
|
||||||
c.reader = reader
|
c.reader = reader
|
||||||
|
|
||||||
if debug.Enabled {
|
|
||||||
logger.Trace("ended client handshake response")
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -473,10 +420,6 @@ type clientPacketConn struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||||
if debug.Enabled {
|
|
||||||
logger.Trace("begin client packet")
|
|
||||||
}
|
|
||||||
|
|
||||||
var hdrLen int
|
var hdrLen int
|
||||||
if c.method.udpCipher != nil {
|
if c.method.udpCipher != nil {
|
||||||
hdrLen = PacketNonceSize
|
hdrLen = PacketNonceSize
|
||||||
|
@ -503,49 +446,26 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad
|
||||||
dataIndex = aes.BlockSize
|
dataIndex = aes.BlockSize
|
||||||
}
|
}
|
||||||
|
|
||||||
packetId := c.session.nextPacketId()
|
|
||||||
|
|
||||||
if debug.Enabled {
|
|
||||||
logger.Trace("session id ", c.session.sessionId)
|
|
||||||
logger.Trace("packet id ", packetId)
|
|
||||||
}
|
|
||||||
|
|
||||||
common.Must(
|
common.Must(
|
||||||
binary.Write(header, binary.BigEndian, c.session.sessionId),
|
binary.Write(header, binary.BigEndian, c.session.sessionId),
|
||||||
binary.Write(header, binary.BigEndian, packetId),
|
binary.Write(header, binary.BigEndian, c.session.nextPacketId()),
|
||||||
)
|
)
|
||||||
|
|
||||||
if c.method.udpCipher == nil && pskLen > 1 {
|
if c.method.udpCipher == nil && pskLen > 1 {
|
||||||
if debug.Enabled {
|
|
||||||
logger.Trace("begin extended identity headers")
|
|
||||||
}
|
|
||||||
for i, psk := range c.method.pskList {
|
for i, psk := range c.method.pskList {
|
||||||
if debug.Enabled {
|
|
||||||
logger.Trace("current psk ", buf.EncodeHexString(psk[:]))
|
|
||||||
}
|
|
||||||
|
|
||||||
dataIndex += aes.BlockSize
|
dataIndex += aes.BlockSize
|
||||||
pskHash := c.method.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
|
pskHash := c.method.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
|
||||||
if debug.Enabled {
|
|
||||||
logger.Trace("next psk hash ", buf.EncodeHexString(pskHash))
|
|
||||||
}
|
|
||||||
|
|
||||||
identityHeader := header.Extend(aes.BlockSize)
|
identityHeader := header.Extend(aes.BlockSize)
|
||||||
for textI := 0; textI < aes.BlockSize; textI++ {
|
for textI := 0; textI < aes.BlockSize; textI++ {
|
||||||
identityHeader[textI] = pskHash[textI] ^ header.Byte(textI)
|
identityHeader[textI] = pskHash[textI] ^ header.Byte(textI)
|
||||||
}
|
}
|
||||||
c.method.blockConstructor(psk[:]).Encrypt(identityHeader, identityHeader)
|
c.method.blockConstructor(psk[:]).Encrypt(identityHeader, identityHeader)
|
||||||
if debug.Enabled {
|
|
||||||
logger.Trace("encoded ", buf.EncodeHexString(identityHeader))
|
|
||||||
}
|
|
||||||
|
|
||||||
if i == pskLen-2 {
|
if i == pskLen-2 {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if debug.Enabled {
|
|
||||||
logger.Trace("end extended identity headers")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
common.Must(
|
common.Must(
|
||||||
header.WriteByte(HeaderTypeClient),
|
header.WriteByte(HeaderTypeClient),
|
||||||
|
@ -568,9 +488,6 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad
|
||||||
buffer.Extend(c.session.cipher.Overhead())
|
buffer.Extend(c.session.cipher.Overhead())
|
||||||
c.method.udpBlockCipher.Encrypt(packetHeader, packetHeader)
|
c.method.udpBlockCipher.Encrypt(packetHeader, packetHeader)
|
||||||
}
|
}
|
||||||
if debug.Enabled {
|
|
||||||
logger.Trace("ended client packet")
|
|
||||||
}
|
|
||||||
return common.Error(c.Write(buffer.Bytes()))
|
return common.Error(c.Write(buffer.Bytes()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -581,10 +498,6 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
|
||||||
}
|
}
|
||||||
buffer.Truncate(n)
|
buffer.Truncate(n)
|
||||||
|
|
||||||
if debug.Enabled {
|
|
||||||
logger.Trace("begin server packet")
|
|
||||||
}
|
|
||||||
|
|
||||||
var packetHeader []byte
|
var packetHeader []byte
|
||||||
if c.method.udpCipher != nil {
|
if c.method.udpCipher != nil {
|
||||||
_, err = c.method.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil)
|
_, err = c.method.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil)
|
||||||
|
@ -607,11 +520,6 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
|
||||||
return M.Socksaddr{}, err
|
return M.Socksaddr{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if debug.Enabled {
|
|
||||||
logger.Trace("session id ", sessionId)
|
|
||||||
logger.Trace("packet id ", packetId)
|
|
||||||
}
|
|
||||||
|
|
||||||
var remoteCipher cipher.AEAD
|
var remoteCipher cipher.AEAD
|
||||||
if packetHeader != nil {
|
if packetHeader != nil {
|
||||||
if sessionId == c.session.remoteSessionId {
|
if sessionId == c.session.remoteSessionId {
|
||||||
|
@ -646,9 +554,6 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
|
||||||
|
|
||||||
diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch))))
|
diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch))))
|
||||||
if diff > 30 {
|
if diff > 30 {
|
||||||
if debug.Enabled {
|
|
||||||
logger.Trace("server timestamp ", time.Unix(int64(epoch), 0).String(), " diff ", diff)
|
|
||||||
}
|
|
||||||
return M.Socksaddr{}, ErrBadTimestamp
|
return M.Socksaddr{}, ErrBadTimestamp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -700,12 +605,102 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return M.Socksaddr{}, err
|
return M.Socksaddr{}, err
|
||||||
}
|
}
|
||||||
if debug.Enabled {
|
|
||||||
logger.Trace("ended client packet")
|
|
||||||
}
|
|
||||||
return destination, nil
|
return destination, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||||
|
buffer := buf.With(p)
|
||||||
|
destination, err := c.ReadPacket(buffer)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
addr = destination.UDPAddr()
|
||||||
|
n = copy(p, buffer.Bytes())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||||
|
destination := M.SocksaddrFromNet(addr)
|
||||||
|
var overHead int
|
||||||
|
if c.method.udpCipher != nil {
|
||||||
|
overHead = PacketNonceSize + c.method.udpCipher.Overhead()
|
||||||
|
} else {
|
||||||
|
overHead = c.session.cipher.Overhead()
|
||||||
|
}
|
||||||
|
overHead += 16 // packet header
|
||||||
|
pskLen := len(c.method.pskList)
|
||||||
|
if c.method.udpCipher == nil && pskLen > 1 {
|
||||||
|
overHead += (pskLen - 1) * aes.BlockSize
|
||||||
|
}
|
||||||
|
overHead += 1 // header type
|
||||||
|
overHead += 8 // timestamp
|
||||||
|
overHead += 2 // padding length
|
||||||
|
overHead += socks5.AddressSerializer.AddrPortLen(destination)
|
||||||
|
|
||||||
|
_buffer := buf.Make(overHead + len(p))
|
||||||
|
defer runtime.KeepAlive(_buffer)
|
||||||
|
buffer := buf.With(common.Dup(_buffer))
|
||||||
|
|
||||||
|
var dataIndex int
|
||||||
|
if c.method.udpCipher != nil {
|
||||||
|
common.Must1(buffer.ReadFullFrom(c.method.secureRNG, PacketNonceSize))
|
||||||
|
if pskLen > 1 {
|
||||||
|
panic("unsupported chacha extended header")
|
||||||
|
}
|
||||||
|
dataIndex = buffer.Len()
|
||||||
|
} else {
|
||||||
|
dataIndex = aes.BlockSize
|
||||||
|
}
|
||||||
|
|
||||||
|
common.Must(
|
||||||
|
binary.Write(buffer, binary.BigEndian, c.session.sessionId),
|
||||||
|
binary.Write(buffer, binary.BigEndian, c.session.nextPacketId()),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.method.udpCipher == nil && pskLen > 1 {
|
||||||
|
for i, psk := range c.method.pskList {
|
||||||
|
dataIndex += aes.BlockSize
|
||||||
|
pskHash := c.method.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
|
||||||
|
|
||||||
|
identityHeader := buffer.Extend(aes.BlockSize)
|
||||||
|
for textI := 0; textI < aes.BlockSize; textI++ {
|
||||||
|
identityHeader[textI] = pskHash[textI] ^ buffer.Byte(textI)
|
||||||
|
}
|
||||||
|
c.method.blockConstructor(psk[:]).Encrypt(identityHeader, identityHeader)
|
||||||
|
|
||||||
|
if i == pskLen-2 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
common.Must(
|
||||||
|
buffer.WriteByte(HeaderTypeClient),
|
||||||
|
binary.Write(buffer, binary.BigEndian, uint64(time.Now().Unix())),
|
||||||
|
binary.Write(buffer, binary.BigEndian, uint16(0)), // padding length
|
||||||
|
)
|
||||||
|
err = socks5.AddressSerializer.WriteAddrPort(buffer, destination)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if c.method.udpCipher != nil {
|
||||||
|
c.method.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil)
|
||||||
|
buffer.Extend(c.method.udpCipher.Overhead())
|
||||||
|
} else {
|
||||||
|
packetHeader := buffer.To(aes.BlockSize)
|
||||||
|
c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil)
|
||||||
|
buffer.Extend(c.session.cipher.Overhead())
|
||||||
|
c.method.udpBlockCipher.Encrypt(packetHeader, packetHeader)
|
||||||
|
}
|
||||||
|
err = common.Error(c.Write(buffer.Bytes()))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
type udpSession struct {
|
type udpSession struct {
|
||||||
headerType byte
|
headerType byte
|
||||||
sessionId uint64
|
sessionId uint64
|
||||||
|
|
|
@ -16,7 +16,6 @@ import (
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
"github.com/sagernet/sing/common/buf"
|
"github.com/sagernet/sing/common/buf"
|
||||||
"github.com/sagernet/sing/common/cache"
|
"github.com/sagernet/sing/common/cache"
|
||||||
"github.com/sagernet/sing/common/debug"
|
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
@ -87,20 +86,12 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
|
func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
|
||||||
if debug.Enabled {
|
|
||||||
logger.Trace("begin server handshake request")
|
|
||||||
}
|
|
||||||
|
|
||||||
requestSalt := make([]byte, KeySaltSize)
|
requestSalt := make([]byte, KeySaltSize)
|
||||||
_, err := io.ReadFull(conn, requestSalt)
|
_, err := io.ReadFull(conn, requestSalt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "read request salt")
|
return E.Cause(err, "read request salt")
|
||||||
}
|
}
|
||||||
|
|
||||||
if debug.Enabled {
|
|
||||||
logger.Trace("read salt ", buf.EncodeHexString(requestSalt))
|
|
||||||
}
|
|
||||||
|
|
||||||
if !s.replayFilter.Check(requestSalt) {
|
if !s.replayFilter.Check(requestSalt) {
|
||||||
return E.New("salt not unique")
|
return E.New("salt not unique")
|
||||||
}
|
}
|
||||||
|
@ -127,9 +118,6 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "read timestamp")
|
return E.Cause(err, "read timestamp")
|
||||||
}
|
}
|
||||||
if debug.Enabled {
|
|
||||||
logger.Trace("client timestamp ", time.Unix(int64(epoch), 0).String())
|
|
||||||
}
|
|
||||||
|
|
||||||
if math.Abs(float64(time.Now().Unix()-int64(epoch))) > 30 {
|
if math.Abs(float64(time.Now().Unix()-int64(epoch))) > 30 {
|
||||||
return ErrBadTimestamp
|
return ErrBadTimestamp
|
||||||
|
@ -139,9 +127,6 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "read destination")
|
return E.Cause(err, "read destination")
|
||||||
}
|
}
|
||||||
if debug.Enabled {
|
|
||||||
logger.Trace("destination ", destination)
|
|
||||||
}
|
|
||||||
|
|
||||||
var paddingLen uint16
|
var paddingLen uint16
|
||||||
err = binary.Read(reader, binary.BigEndian, &paddingLen)
|
err = binary.Read(reader, binary.BigEndian, &paddingLen)
|
||||||
|
@ -149,14 +134,6 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M
|
||||||
return E.Cause(err, "read padding length")
|
return E.Cause(err, "read padding length")
|
||||||
}
|
}
|
||||||
|
|
||||||
if debug.Enabled {
|
|
||||||
if paddingLen > 0 {
|
|
||||||
logger.Trace("padding ", paddingLen, "B")
|
|
||||||
} else {
|
|
||||||
logger.Trace("no padding")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if paddingLen > 0 {
|
if paddingLen > 0 {
|
||||||
err = reader.Discard(int(paddingLen))
|
err = reader.Discard(int(paddingLen))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -164,10 +141,6 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if debug.Enabled {
|
|
||||||
logger.Trace("end server handshake request")
|
|
||||||
}
|
|
||||||
|
|
||||||
metadata.Protocol = "shadowsocks"
|
metadata.Protocol = "shadowsocks"
|
||||||
metadata.Destination = destination
|
metadata.Destination = destination
|
||||||
return s.handler.NewConnection(ctx, &serverConn{
|
return s.handler.NewConnection(ctx, &serverConn{
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue