mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-04 12:27:37 +03:00
Refine buffer
This commit is contained in:
parent
603c62165e
commit
63ef20617a
13 changed files with 179 additions and 186 deletions
|
@ -141,6 +141,7 @@ func checkUpdate() {
|
||||||
Name: domain,
|
Name: domain,
|
||||||
Content: content,
|
Content: content,
|
||||||
Proxied: &overProxy,
|
Proxied: &overProxy,
|
||||||
|
TTL: 60,
|
||||||
}
|
}
|
||||||
if addr.Is4() {
|
if addr.Is4() {
|
||||||
record.Type = "A"
|
record.Type = "A"
|
||||||
|
|
|
@ -192,7 +192,11 @@ func NewLocalClient(f *flags) (*LocalClient, error) {
|
||||||
if len(pskList) > 1 {
|
if len(pskList) > 1 {
|
||||||
return nil, shadowaead.ErrBadKey
|
return nil, shadowaead.ErrBadKey
|
||||||
}
|
}
|
||||||
method, err := shadowaead.New(f.Method, pskList[0], []byte(f.Password), rng, false)
|
var key []byte
|
||||||
|
if len(pskList) > 0 {
|
||||||
|
key = pskList[0]
|
||||||
|
}
|
||||||
|
method, err := shadowaead.New(f.Method, key, []byte(f.Password), rng, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -314,27 +318,21 @@ func (c *LocalClient) NewConnection(conn net.Conn, metadata M.Metadata) error {
|
||||||
logrus.Info("CONNECT ", conn.RemoteAddr(), " ==> ", metadata.Destination)
|
logrus.Info("CONNECT ", conn.RemoteAddr(), " ==> ", metadata.Destination)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
var serverConn net.Conn
|
serverConn, err := c.dialer.DialContext(ctx, "tcp", c.server.String())
|
||||||
payload := buf.New()
|
if err != nil {
|
||||||
err := task.Run(ctx, func() error {
|
return E.Cause(err, "connect to server")
|
||||||
sc, err := c.dialer.DialContext(ctx, "tcp", c.server.String())
|
}
|
||||||
serverConn = sc
|
_payload := buf.StackNew()
|
||||||
if err != nil {
|
payload := common.Dup(_payload)
|
||||||
return E.Cause(err, "connect to server")
|
err = conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
|
||||||
}
|
if err != nil {
|
||||||
return nil
|
|
||||||
}, func() error {
|
|
||||||
err := conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = payload.ReadFrom(conn)
|
|
||||||
if err != nil && !E.IsTimeout(err) {
|
|
||||||
return E.Cause(err, "read payload")
|
|
||||||
}
|
|
||||||
err = conn.SetReadDeadline(time.Time{})
|
|
||||||
return err
|
return err
|
||||||
})
|
}
|
||||||
|
_, err = payload.ReadFrom(conn)
|
||||||
|
if err != nil && !E.IsTimeout(err) {
|
||||||
|
return E.Cause(err, "read payload")
|
||||||
|
}
|
||||||
|
err = conn.SetReadDeadline(time.Time{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
payload.Release()
|
payload.Release()
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -25,31 +25,15 @@ func New() *Buffer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSize(size int) *Buffer {
|
func StackNew() *Buffer {
|
||||||
if size <= 128 || size > BufferSize {
|
|
||||||
return &Buffer{
|
|
||||||
data: make([]byte, size),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &Buffer{
|
return &Buffer{
|
||||||
data: GetBytes(),
|
data: make([]byte, BufferSize),
|
||||||
start: ReversedHeader,
|
|
||||||
end: ReversedHeader,
|
|
||||||
managed: true,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func FullNew() *Buffer {
|
func StackNewSize(size int) *Buffer {
|
||||||
return &Buffer{
|
return &Buffer{
|
||||||
data: GetBytes(),
|
data: Make(size),
|
||||||
managed: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func StackNew() Buffer {
|
|
||||||
return Buffer{
|
|
||||||
data: GetBytes(),
|
|
||||||
managed: true,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -71,20 +55,6 @@ func As(data []byte) *Buffer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Or(data []byte, size int) *Buffer {
|
|
||||||
max := cap(data)
|
|
||||||
if size != max {
|
|
||||||
data = data[:max]
|
|
||||||
}
|
|
||||||
if cap(data) >= size {
|
|
||||||
return &Buffer{
|
|
||||||
data: data,
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return NewSize(size)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func With(data []byte) *Buffer {
|
func With(data []byte) *Buffer {
|
||||||
return &Buffer{
|
return &Buffer{
|
||||||
data: data,
|
data: data,
|
||||||
|
|
|
@ -10,29 +10,41 @@ const (
|
||||||
|
|
||||||
var pool = sync.Pool{
|
var pool = sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
var buffer [BufferSize]byte
|
buffer := make([]byte, BufferSize)
|
||||||
return buffer[:]
|
return &buffer
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetBytes() []byte {
|
func GetBytes() []byte {
|
||||||
return pool.Get().([]byte)
|
return *pool.Get().(*[]byte)
|
||||||
}
|
}
|
||||||
|
|
||||||
func PutBytes(buffer []byte) {
|
func PutBytes(buffer []byte) {
|
||||||
pool.Put(buffer)
|
pool.Put(&buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Make(size int) []byte {
|
func Make(size int) []byte {
|
||||||
var buffer []byte
|
var buffer []byte
|
||||||
if size <= 64 {
|
if size <= 16 {
|
||||||
|
buffer = make([]byte, 16)
|
||||||
|
} else if size <= 32 {
|
||||||
|
buffer = make([]byte, 32)
|
||||||
|
} else if size <= 64 {
|
||||||
buffer = make([]byte, 64)
|
buffer = make([]byte, 64)
|
||||||
|
} else if size <= 128 {
|
||||||
|
buffer = make([]byte, 128)
|
||||||
|
} else if size <= 256 {
|
||||||
|
buffer = make([]byte, 256)
|
||||||
|
} else if size <= 512 {
|
||||||
|
buffer = make([]byte, 512)
|
||||||
} else if size <= 1024 {
|
} else if size <= 1024 {
|
||||||
buffer = make([]byte, 1024)
|
buffer = make([]byte, 1024)
|
||||||
} else if size <= 4096 {
|
} else if size <= 4*1024 {
|
||||||
buffer = make([]byte, 4096)
|
buffer = make([]byte, 4*1024)
|
||||||
} else if size <= 16384 {
|
} else if size <= 16*1024 {
|
||||||
buffer = make([]byte, 16384)
|
buffer = make([]byte, 16*1024)
|
||||||
|
} else if size <= 20*1024 {
|
||||||
|
buffer = make([]byte, 20*1024)
|
||||||
} else if size <= 65535 {
|
} else if size <= 65535 {
|
||||||
buffer = make([]byte, 65535)
|
buffer = make([]byte, 65535)
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -1,13 +1,13 @@
|
||||||
package lowmem
|
package lowmem
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"runtime"
|
"runtime/debug"
|
||||||
)
|
)
|
||||||
|
|
||||||
var Enabled = false
|
var Enabled = false
|
||||||
|
|
||||||
func Free() {
|
func Free() {
|
||||||
if Enabled {
|
if Enabled {
|
||||||
runtime.GC()
|
debug.FreeOSMemory()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -57,8 +57,8 @@ func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error {
|
||||||
|
|
||||||
func CopyPacketConn(ctx context.Context, conn net.PacketConn, outPacketConn net.PacketConn) error {
|
func CopyPacketConn(ctx context.Context, conn net.PacketConn, outPacketConn net.PacketConn) error {
|
||||||
return task.Run(ctx, func() error {
|
return task.Run(ctx, func() error {
|
||||||
buffer := buf.FullNew()
|
_buffer := buf.With(make([]byte, buf.UDPBufferSize))
|
||||||
defer buffer.Release()
|
buffer := common.Dup(_buffer)
|
||||||
for {
|
for {
|
||||||
n, addr, err := conn.ReadFrom(buffer.FreeBytes())
|
n, addr, err := conn.ReadFrom(buffer.FreeBytes())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -72,8 +72,8 @@ func CopyPacketConn(ctx context.Context, conn net.PacketConn, outPacketConn net.
|
||||||
buffer.FullReset()
|
buffer.FullReset()
|
||||||
}
|
}
|
||||||
}, func() error {
|
}, func() error {
|
||||||
buffer := buf.FullNew()
|
_buffer := buf.With(make([]byte, buf.UDPBufferSize))
|
||||||
defer buffer.Release()
|
buffer := common.Dup(_buffer)
|
||||||
for {
|
for {
|
||||||
n, addr, err := outPacketConn.ReadFrom(buffer.FreeBytes())
|
n, addr, err := outPacketConn.ReadFrom(buffer.FreeBytes())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
11
common/rw/writev_posix.go
Normal file
11
common/rw/writev_posix.go
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package rw
|
||||||
|
|
||||||
|
import (
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func WriteV(fd uintptr, data ...[]byte) (int, error) {
|
||||||
|
return unix.Writev(int(fd), data)
|
||||||
|
}
|
16
common/rw/writev_windows.go
Normal file
16
common/rw/writev_windows.go
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
package rw
|
||||||
|
|
||||||
|
import "golang.org/x/sys/windows"
|
||||||
|
|
||||||
|
func WriteV(fd uintptr, data ...[]byte) (int, error) {
|
||||||
|
var n uint32
|
||||||
|
buffers := make([]*windows.WSABuf, len(data))
|
||||||
|
for i, buf := range data {
|
||||||
|
buffers[i] = &windows.WSABuf{
|
||||||
|
Len: uint32(len(buf)),
|
||||||
|
Buf: &buf[0],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err := windows.WSASend(windows.Handle(fd), buffers[0], uint32(len(buffers)), &n, 0, nil, nil)
|
||||||
|
return int(n), err
|
||||||
|
}
|
|
@ -40,8 +40,8 @@ func (c *ServerConn) RemoteAddr() net.Addr {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ServerConn) loopInput() {
|
func (c *ServerConn) loopInput() {
|
||||||
buffer := buf.New()
|
_buffer := buf.StackNew()
|
||||||
defer buffer.Release()
|
buffer := common.Dup(_buffer)
|
||||||
for {
|
for {
|
||||||
destination, err := AddrParser.ReadAddrPort(c.inputReader)
|
destination, err := AddrParser.ReadAddrPort(c.inputReader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -73,8 +73,8 @@ func (c *ServerConn) loopInput() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ServerConn) loopOutput() {
|
func (c *ServerConn) loopOutput() {
|
||||||
buffer := buf.New()
|
_buffer := buf.StackNew()
|
||||||
defer buffer.Release()
|
buffer := common.Dup(_buffer)
|
||||||
for {
|
for {
|
||||||
buffer.FullReset()
|
buffer.FullReset()
|
||||||
n, addr, err := buffer.ReadPacketFrom(c)
|
n, addr, err := buffer.ReadPacketFrom(c)
|
||||||
|
|
|
@ -29,8 +29,8 @@ func TestServerConn(t *testing.T) {
|
||||||
IP: net.IPv4(8, 8, 8, 8),
|
IP: net.IPv4(8, 8, 8, 8),
|
||||||
Port: 53,
|
Port: 53,
|
||||||
}))
|
}))
|
||||||
buffer := buf.New()
|
_buffer := buf.StackNew()
|
||||||
defer buffer.Release()
|
buffer := common.Dup(_buffer)
|
||||||
common.Must2(buffer.ReadPacketFrom(clientConn))
|
common.Must2(buffer.ReadPacketFrom(clientConn))
|
||||||
common.Must(message.Unpack(buffer.Bytes()))
|
common.Must(message.Unpack(buffer.Bytes()))
|
||||||
for _, answer := range message.Answers {
|
for _, answer := range message.Answers {
|
||||||
|
|
|
@ -81,8 +81,8 @@ func (c *noneConn) Write(b []byte) (n int, err error) {
|
||||||
return 0, c.clientHandshake()
|
return 0, c.clientHandshake()
|
||||||
}
|
}
|
||||||
|
|
||||||
buffer := buf.New()
|
_buffer := buf.StackNew()
|
||||||
defer buffer.Release()
|
buffer := common.Dup(_buffer)
|
||||||
|
|
||||||
err = socks.AddressSerializer.WriteAddrPort(buffer, c.destination)
|
err = socks.AddressSerializer.WriteAddrPort(buffer, c.destination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -138,7 +138,8 @@ func (c *nonePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
||||||
|
|
||||||
func (c *nonePacketConn) WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error {
|
func (c *nonePacketConn) WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error {
|
||||||
defer buffer.Release()
|
defer buffer.Release()
|
||||||
header := buf.New()
|
_header := buf.StackNew()
|
||||||
|
header := common.Dup(_header)
|
||||||
err := socks.AddressSerializer.WriteAddrPort(header, addrPort)
|
err := socks.AddressSerializer.WriteAddrPort(header, addrPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
header.Release()
|
header.Release()
|
||||||
|
|
|
@ -79,9 +79,9 @@ func New(method string, key []byte, password []byte, secureRNG io.Reader, replay
|
||||||
}
|
}
|
||||||
|
|
||||||
func Kdf(key, iv []byte, keyLength int) []byte {
|
func Kdf(key, iv []byte, keyLength int) []byte {
|
||||||
subKey := make([]byte, keyLength)
|
subKey := buf.Make(keyLength)
|
||||||
kdf := hkdf.New(sha1.New, key, iv, []byte("ss-subkey"))
|
kdf := hkdf.New(sha1.New, key, iv, []byte("ss-subkey"))
|
||||||
common.Must1(io.ReadFull(kdf, subKey))
|
common.Must1(io.ReadFull(kdf, common.Dup(subKey)))
|
||||||
return subKey
|
return subKey
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -111,8 +111,8 @@ func (m *Method) KeyLength() int {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Method) ReadRequest(upstream io.Reader) (io.Reader, error) {
|
func (m *Method) ReadRequest(upstream io.Reader) (io.Reader, error) {
|
||||||
saltBuffer := buf.Make(m.keySaltLength)
|
_salt := buf.Make(m.keySaltLength)
|
||||||
salt := common.Dup(saltBuffer)
|
salt := common.Dup(_salt)
|
||||||
_, err := io.ReadFull(upstream, salt)
|
_, err := io.ReadFull(upstream, salt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, E.Cause(err, "read salt")
|
return nil, E.Cause(err, "read salt")
|
||||||
|
@ -122,18 +122,20 @@ func (m *Method) ReadRequest(upstream io.Reader) (io.Reader, error) {
|
||||||
return nil, E.New("salt not unique")
|
return nil, E.New("salt not unique")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return NewReader(upstream, m.constructor(Kdf(m.key, salt, m.keySaltLength)), MaxPacketSize), nil
|
key := Kdf(m.key, salt, m.keySaltLength)
|
||||||
|
return NewReader(upstream, m.constructor(common.Dup(key)), MaxPacketSize), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Method) WriteResponse(upstream io.Writer) (io.Writer, error) {
|
func (m *Method) WriteResponse(upstream io.Writer) (io.Writer, error) {
|
||||||
saltBuffer := buf.Make(m.keySaltLength)
|
_salt := buf.Make(m.keySaltLength)
|
||||||
salt := common.Dup(saltBuffer)
|
salt := common.Dup(_salt)
|
||||||
common.Must1(io.ReadFull(m.secureRNG, salt))
|
common.Must1(io.ReadFull(m.secureRNG, salt))
|
||||||
_, err := upstream.Write(salt)
|
_, err := upstream.Write(salt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return NewWriter(upstream, m.constructor(Kdf(m.key, salt, m.keySaltLength)), MaxPacketSize), nil
|
key := Kdf(m.key, salt, m.keySaltLength)
|
||||||
|
return NewWriter(upstream, m.constructor(common.Dup(key)), MaxPacketSize), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Method) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, error) {
|
func (m *Method) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, error) {
|
||||||
|
@ -154,11 +156,12 @@ func (m *Method) DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Method) DialPacketConn(conn net.Conn) socks.PacketConn {
|
func (m *Method) DialPacketConn(conn net.Conn) socks.PacketConn {
|
||||||
return &aeadPacketConn{conn, m}
|
return &clientPacketConn{conn, m}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Method) EncodePacket(buffer *buf.Buffer) error {
|
func (m *Method) EncodePacket(buffer *buf.Buffer) error {
|
||||||
c := m.constructor(Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength))
|
key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength)
|
||||||
|
c := m.constructor(common.Dup(key))
|
||||||
c.Seal(buffer.From(m.keySaltLength)[:0], rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil)
|
c.Seal(buffer.From(m.keySaltLength)[:0], rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil)
|
||||||
buffer.Extend(c.Overhead())
|
buffer.Extend(c.Overhead())
|
||||||
return nil
|
return nil
|
||||||
|
@ -168,7 +171,8 @@ func (m *Method) DecodePacket(buffer *buf.Buffer) error {
|
||||||
if buffer.Len() < m.keySaltLength {
|
if buffer.Len() < m.keySaltLength {
|
||||||
return E.New("bad packet")
|
return E.New("bad packet")
|
||||||
}
|
}
|
||||||
c := m.constructor(Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength))
|
key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength)
|
||||||
|
c := m.constructor(common.Dup(key))
|
||||||
packet, err := c.Open(buffer.Index(m.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil)
|
packet, err := c.Open(buffer.Index(m.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -190,8 +194,8 @@ type clientConn struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientConn) writeRequest(payload []byte) error {
|
func (c *clientConn) writeRequest(payload []byte) error {
|
||||||
request := buf.New()
|
_request := buf.StackNew()
|
||||||
defer request.Release()
|
request := common.Dup(_request)
|
||||||
|
|
||||||
common.Must1(request.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
|
common.Must1(request.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
|
||||||
|
|
||||||
|
@ -207,8 +211,8 @@ func (c *clientConn) writeRequest(payload []byte) error {
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(payload) > 0 {
|
if len(payload) > 0 {
|
||||||
header := buf.New()
|
_header := buf.StackNew()
|
||||||
defer header.Release()
|
header := common.Dup(_header)
|
||||||
|
|
||||||
writer = &buf.BufferedWriter{
|
writer = &buf.BufferedWriter{
|
||||||
Writer: writer,
|
Writer: writer,
|
||||||
|
@ -240,23 +244,26 @@ func (c *clientConn) writeRequest(payload []byte) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientConn) readResponse() error {
|
func (c *clientConn) readResponse() error {
|
||||||
if c.reader == nil {
|
if c.reader != nil {
|
||||||
salt := make([]byte, c.method.keySaltLength)
|
return nil
|
||||||
_, err := io.ReadFull(c.Conn, salt)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if c.method.replayFilter != nil {
|
|
||||||
if !c.method.replayFilter.Check(salt) {
|
|
||||||
return E.New("salt not unique")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.reader = NewReader(
|
|
||||||
c.Conn,
|
|
||||||
c.method.constructor(Kdf(c.method.key, salt, c.method.keySaltLength)),
|
|
||||||
MaxPacketSize,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
_salt := buf.Make(c.method.keySaltLength)
|
||||||
|
salt := common.Dup(_salt)
|
||||||
|
_, err := io.ReadFull(c.Conn, salt)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if c.method.replayFilter != nil {
|
||||||
|
if !c.method.replayFilter.Check(salt) {
|
||||||
|
return E.New("salt not unique")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
key := Kdf(c.method.key, salt, c.method.keySaltLength)
|
||||||
|
c.reader = NewReader(
|
||||||
|
c.Conn,
|
||||||
|
c.method.constructor(common.Dup(key)),
|
||||||
|
MaxPacketSize,
|
||||||
|
)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -300,14 +307,14 @@ func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||||
return c.writer.(io.ReaderFrom).ReadFrom(r)
|
return c.writer.(io.ReaderFrom).ReadFrom(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
type aeadPacketConn struct {
|
type clientPacketConn struct {
|
||||||
net.Conn
|
net.Conn
|
||||||
method *Method
|
method *Method
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *aeadPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
|
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
|
||||||
defer buffer.Release()
|
_header := buf.StackNew()
|
||||||
header := buf.New()
|
header := common.Dup(_header)
|
||||||
common.Must1(header.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
|
common.Must1(header.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
|
||||||
err := socks.AddressSerializer.WriteAddrPort(header, destination)
|
err := socks.AddressSerializer.WriteAddrPort(header, destination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -321,7 +328,7 @@ func (c *aeadPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort
|
||||||
return common.Error(c.Write(buffer.Bytes()))
|
return common.Error(c.Write(buffer.Bytes()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *aeadPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
||||||
n, err := c.Read(buffer.FreeBytes())
|
n, err := c.Read(buffer.FreeBytes())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -22,21 +22,18 @@ import (
|
||||||
"github.com/sagernet/sing/protocol/shadowsocks"
|
"github.com/sagernet/sing/protocol/shadowsocks"
|
||||||
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
|
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
|
||||||
"github.com/sagernet/sing/protocol/socks"
|
"github.com/sagernet/sing/protocol/socks"
|
||||||
"golang.org/x/crypto/chacha20"
|
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
wgReplay "golang.zx2c4.com/wireguard/replay"
|
wgReplay "golang.zx2c4.com/wireguard/replay"
|
||||||
"lukechampine.com/blake3"
|
"lukechampine.com/blake3"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
HeaderTypeClient = 0
|
HeaderTypeClient = 0
|
||||||
HeaderTypeServer = 1
|
HeaderTypeServer = 1
|
||||||
MaxPaddingLength = 900
|
MaxPaddingLength = 900
|
||||||
KeySaltSize = 32
|
KeySaltSize = 32
|
||||||
PacketNonceSize = 24
|
PacketNonceSize = 24
|
||||||
MinRequestHeaderSize = 1 + 8
|
MaxPacketSize = 65535
|
||||||
MinResponseHeaderSize = MinRequestHeaderSize + KeySaltSize
|
|
||||||
MaxPacketSize = 65535 + shadowaead.PacketLengthBufferSize + nonceSize*2
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -106,7 +103,6 @@ func New(method string, pskList [][]byte, secureRNG io.Reader) (shadowsocks.Meth
|
||||||
case "2022-blake3-chacha20-poly1305":
|
case "2022-blake3-chacha20-poly1305":
|
||||||
m.keyLength = 32
|
m.keyLength = 32
|
||||||
m.constructor = newChacha20Poly1305
|
m.constructor = newChacha20Poly1305
|
||||||
m.streamConstructor = newChacha20
|
|
||||||
m.udpCipher = newXChacha20Poly1305(m.psk)
|
m.udpCipher = newXChacha20Poly1305(m.psk)
|
||||||
}
|
}
|
||||||
return m, nil
|
return m, nil
|
||||||
|
@ -135,12 +131,6 @@ func newAESGCM(key []byte) cipher.AEAD {
|
||||||
return aead
|
return aead
|
||||||
}
|
}
|
||||||
|
|
||||||
func newChacha20(key []byte) cipher.Stream {
|
|
||||||
_nonce := make([]byte, chacha20.NonceSize)
|
|
||||||
stream, _ := chacha20.NewUnauthenticatedCipher(key, common.Dup(_nonce))
|
|
||||||
return stream
|
|
||||||
}
|
|
||||||
|
|
||||||
func newChacha20Poly1305(key []byte) cipher.AEAD {
|
func newChacha20Poly1305(key []byte) cipher.AEAD {
|
||||||
cipher, err := chacha20poly1305.New(key)
|
cipher, err := chacha20poly1305.New(key)
|
||||||
common.Must(err)
|
common.Must(err)
|
||||||
|
@ -154,18 +144,17 @@ func newXChacha20Poly1305(key []byte) cipher.AEAD {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Method struct {
|
type Method struct {
|
||||||
name string
|
name string
|
||||||
keyLength int
|
keyLength int
|
||||||
constructor func(key []byte) cipher.AEAD
|
constructor func(key []byte) cipher.AEAD
|
||||||
blockConstructor func(key []byte) cipher.Block
|
blockConstructor func(key []byte) cipher.Block
|
||||||
streamConstructor func(key []byte) cipher.Stream
|
udpCipher cipher.AEAD
|
||||||
udpCipher cipher.AEAD
|
udpBlockCipher cipher.Block
|
||||||
udpBlockCipher cipher.Block
|
psk []byte
|
||||||
psk []byte
|
pskList [][]byte
|
||||||
pskList [][]byte
|
pskHash []byte
|
||||||
pskHash []byte
|
secureRNG io.Reader
|
||||||
secureRNG io.Reader
|
replayFilter replay.Filter
|
||||||
replayFilter replay.Filter
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Method) Name() string {
|
func (m *Method) Name() string {
|
||||||
|
@ -176,30 +165,6 @@ func (m *Method) KeyLength() int {
|
||||||
return m.keyLength
|
return m.keyLength
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Method) WriteExtendedIdentityHeaders(request *buf.Buffer, salt []byte) {
|
|
||||||
pskLen := len(m.pskList)
|
|
||||||
if pskLen < 2 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for i, psk := range m.pskList {
|
|
||||||
keyMaterial := make([]byte, 2*KeySaltSize)
|
|
||||||
copy(keyMaterial, psk)
|
|
||||||
copy(keyMaterial[KeySaltSize:], salt)
|
|
||||||
_identitySubkey := buf.Make(m.keyLength)
|
|
||||||
identitySubkey := common.Dup(_identitySubkey)
|
|
||||||
blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial)
|
|
||||||
pskHash := m.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
|
|
||||||
if m.blockConstructor != nil {
|
|
||||||
m.blockConstructor(identitySubkey).Encrypt(request.Extend(16), pskHash)
|
|
||||||
} else {
|
|
||||||
m.streamConstructor(identitySubkey).XORKeyStream(request.Extend(16), pskHash)
|
|
||||||
}
|
|
||||||
if i == pskLen-2 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Method) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, error) {
|
func (m *Method) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, error) {
|
||||||
shadowsocksConn := &clientConn{
|
shadowsocksConn := &clientConn{
|
||||||
Conn: conn,
|
Conn: conn,
|
||||||
|
@ -236,18 +201,38 @@ type clientConn struct {
|
||||||
writer io.Writer
|
writer io.Writer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) {
|
||||||
|
pskLen := len(m.pskList)
|
||||||
|
if pskLen < 2 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i, psk := range m.pskList {
|
||||||
|
keyMaterial := make([]byte, 2*KeySaltSize)
|
||||||
|
copy(keyMaterial, psk)
|
||||||
|
copy(keyMaterial[KeySaltSize:], salt)
|
||||||
|
_identitySubkey := buf.Make(m.keyLength)
|
||||||
|
identitySubkey := common.Dup(_identitySubkey)
|
||||||
|
blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial)
|
||||||
|
pskHash := m.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
|
||||||
|
m.blockConstructor(identitySubkey).Encrypt(request.Extend(16), pskHash)
|
||||||
|
if i == pskLen-2 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *clientConn) writeRequest(payload []byte) error {
|
func (c *clientConn) writeRequest(payload []byte) error {
|
||||||
request := buf.New()
|
_request := buf.StackNew()
|
||||||
defer request.Release()
|
request := common.Dup(_request)
|
||||||
|
|
||||||
salt := make([]byte, KeySaltSize)
|
salt := make([]byte, KeySaltSize)
|
||||||
common.Must1(io.ReadFull(c.method.secureRNG, salt))
|
common.Must1(io.ReadFull(c.method.secureRNG, salt))
|
||||||
common.Must1(request.Write(salt))
|
common.Must1(request.Write(salt))
|
||||||
c.method.WriteExtendedIdentityHeaders(request, salt)
|
c.method.writeExtendedIdentityHeaders(request, salt)
|
||||||
|
|
||||||
var writer io.Writer = c.Conn
|
var writer io.Writer
|
||||||
writer = &buf.BufferedWriter{
|
writer = &buf.BufferedWriter{
|
||||||
Writer: writer,
|
Writer: c.Conn,
|
||||||
Buffer: request,
|
Buffer: request,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -258,8 +243,8 @@ func (c *clientConn) writeRequest(payload []byte) error {
|
||||||
MaxPacketSize,
|
MaxPacketSize,
|
||||||
)
|
)
|
||||||
|
|
||||||
header := buf.New()
|
_header := buf.StackNew()
|
||||||
defer header.Release()
|
header := common.Dup(_header)
|
||||||
|
|
||||||
writer = &buf.BufferedWriter{
|
writer = &buf.BufferedWriter{
|
||||||
Writer: writer,
|
Writer: writer,
|
||||||
|
@ -362,6 +347,7 @@ func (c *clientConn) readResponse() error {
|
||||||
return ErrBadRequestSalt
|
return ErrBadRequestSalt
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.requestSalt = nil
|
||||||
c.reader = reader
|
c.reader = reader
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -417,23 +403,14 @@ type clientPacketConn struct {
|
||||||
|
|
||||||
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
|
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
|
||||||
defer buffer.Release()
|
defer buffer.Release()
|
||||||
header := buf.New()
|
_header := buf.StackNew()
|
||||||
|
header := common.Dup(_header)
|
||||||
pskLen := len(c.method.pskList)
|
pskLen := len(c.method.pskList)
|
||||||
var dataIndex int
|
var dataIndex int
|
||||||
if c.method.udpCipher != nil {
|
if c.method.udpCipher != nil {
|
||||||
common.Must1(header.ReadFullFrom(c.method.secureRNG, PacketNonceSize))
|
common.Must1(header.ReadFullFrom(c.method.secureRNG, PacketNonceSize))
|
||||||
if pskLen > 1 {
|
if pskLen > 1 {
|
||||||
for i, psk := range c.method.pskList {
|
panic("unsupported chacha extended header")
|
||||||
pskHash := c.method.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
|
|
||||||
identityHeader := header.Extend(aes.BlockSize)
|
|
||||||
for textI := 0; textI < aes.BlockSize; textI++ {
|
|
||||||
identityHeader[textI] = pskHash[textI] ^ header.Byte(textI)
|
|
||||||
}
|
|
||||||
c.method.streamConstructor(psk).XORKeyStream(identityHeader, identityHeader)
|
|
||||||
if i == pskLen-2 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
dataIndex = buffer.Len()
|
dataIndex = buffer.Len()
|
||||||
} else {
|
} else {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue