Fix stream client

This commit is contained in:
世界 2022-05-30 11:26:25 +08:00
parent 77a38dfcfc
commit 9fbb103c01
No known key found for this signature in database
GPG key ID: CD109927C34A63C4

View file

@ -11,7 +11,6 @@ import (
"os" "os"
"github.com/sagernet/sing-shadowsocks" "github.com/sagernet/sing-shadowsocks"
"github.com/sagernet/sing-shadowsocks/shadowaead"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
@ -77,7 +76,7 @@ func New(method string, key []byte, password string) (shadowsocks.Method, error)
m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBDecrypter) m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBDecrypter)
case "rc4-md5": case "rc4-md5":
m.keyLength = 16 m.keyLength = 16
m.saltLength = 0 m.saltLength = 16
m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) { m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
h := md5.New() h := md5.New()
h.Write(key) h.Write(key)
@ -143,17 +142,17 @@ func (m *Method) KeyLength() int {
func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) { func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) {
shadowsocksConn := &clientConn{ shadowsocksConn := &clientConn{
Method: m,
Conn: conn, Conn: conn,
method: m,
destination: destination, destination: destination,
} }
return shadowsocksConn, shadowsocksConn.writeRequest(nil) return shadowsocksConn, shadowsocksConn.writeRequest()
} }
func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn { func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn {
return &clientConn{ return &clientConn{
Method: m,
Conn: conn, Conn: conn,
method: m,
destination: destination, destination: destination,
} }
} }
@ -163,45 +162,40 @@ func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn {
} }
type clientConn struct { type clientConn struct {
*Method
net.Conn net.Conn
method *Method
destination M.Socksaddr destination M.Socksaddr
readStream cipher.Stream readStream cipher.Stream
writeStream cipher.Stream writeStream cipher.Stream
} }
func (c *clientConn) writeRequest(payload []byte) error { func (c *clientConn) writeRequest() error {
_buffer := buf.Make(c.method.keyLength + M.SocksaddrSerializer.AddrPortLen(c.destination) + len(payload)) _buffer := buf.StackNewSize(c.saltLength + M.SocksaddrSerializer.AddrPortLen(c.destination))
defer common.KeepAlive(_buffer) defer common.KeepAlive(_buffer)
buffer := buf.With(common.Dup(_buffer)) buffer := common.Dup(_buffer)
defer buffer.Release()
salt := buffer.Extend(c.method.keyLength) salt := buffer.Extend(c.saltLength)
common.Must1(io.ReadFull(rand.Reader, salt)) common.Must1(io.ReadFull(rand.Reader, salt))
key := shadowaead.Kdf(c.method.key, salt, c.method.keyLength) stream, err := c.encryptConstructor(c.key, salt)
writer, err := c.method.encryptConstructor(c.method.key, salt)
if err != nil { if err != nil {
return err return err
} }
common.KeepAlive(key)
err = M.SocksaddrSerializer.WriteAddrPort(buffer, c.destination) err = M.SocksaddrSerializer.WriteAddrPort(buffer, c.destination)
if err != nil { if err != nil {
return err return err
} }
_, err = buffer.Write(payload)
if err != nil { stream.XORKeyStream(buffer.From(c.saltLength), buffer.From(c.saltLength))
return err
}
_, err = c.Conn.Write(buffer.Bytes()) _, err = c.Conn.Write(buffer.Bytes())
if err != nil { if err != nil {
return err return err
} }
c.writeStream = writer c.writeStream = stream
return nil return nil
} }
@ -209,29 +203,27 @@ func (c *clientConn) readResponse() error {
if c.readStream != nil { if c.readStream != nil {
return nil return nil
} }
_salt := buf.Make(c.method.keyLength) _salt := buf.Make(c.saltLength)
defer common.KeepAlive(_salt) defer common.KeepAlive(_salt)
salt := common.Dup(_salt) salt := common.Dup(_salt)
_, err := io.ReadFull(c.Conn, salt) _, err := io.ReadFull(c.Conn, salt)
if err != nil { if err != nil {
return err return err
} }
key := shadowaead.Kdf(c.method.key, salt, c.method.keyLength) c.readStream, err = c.decryptConstructor(c.key, salt)
defer common.KeepAlive(key) return err
c.readStream, err = c.method.decryptConstructor(common.Dup(key), salt)
if err != nil {
return err
}
return nil
} }
func (c *clientConn) Read(p []byte) (n int, err error) { func (c *clientConn) Read(p []byte) (n int, err error) {
if err = c.readResponse(); err != nil { if c.readStream == nil {
return err = c.readResponse()
if err != nil {
return
}
} }
n, err = c.Conn.Read(p) n, err = c.Conn.Read(p)
if err != nil { if err != nil {
return 0, err return
} }
c.readStream.XORKeyStream(p[:n], p[:n]) c.readStream.XORKeyStream(p[:n], p[:n])
return return
@ -239,11 +231,10 @@ func (c *clientConn) Read(p []byte) (n int, err error) {
func (c *clientConn) Write(p []byte) (n int, err error) { func (c *clientConn) Write(p []byte) (n int, err error) {
if c.writeStream == nil { if c.writeStream == nil {
err = c.writeRequest(p) err = c.writeRequest()
if err == nil { if err != nil {
n = len(p) return
} }
return
} }
c.writeStream.XORKeyStream(p, p) c.writeStream.XORKeyStream(p, p)
@ -261,17 +252,17 @@ type clientPacketConn struct {
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
defer buffer.Release() defer buffer.Release()
header := buf.With(buffer.ExtendHeader(c.keyLength + M.SocksaddrSerializer.AddrPortLen(destination))) header := buf.With(buffer.ExtendHeader(c.saltLength + M.SocksaddrSerializer.AddrPortLen(destination)))
common.Must1(header.ReadFullFrom(rand.Reader, c.keyLength)) common.Must1(header.ReadFullFrom(rand.Reader, c.saltLength))
err := M.SocksaddrSerializer.WriteAddrPort(header, destination) err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
if err != nil { if err != nil {
return err return err
} }
stream, err := c.encryptConstructor(c.key, buffer.To(c.keyLength)) stream, err := c.encryptConstructor(c.key, buffer.To(c.saltLength))
if err != nil { if err != nil {
return err return err
} }
stream.XORKeyStream(buffer.From(c.keyLength), buffer.From(c.keyLength)) stream.XORKeyStream(buffer.From(c.saltLength), buffer.From(c.saltLength))
return common.Error(c.Write(buffer.Bytes())) return common.Error(c.Write(buffer.Bytes()))
} }
@ -281,12 +272,12 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
return M.Socksaddr{}, err return M.Socksaddr{}, err
} }
buffer.Truncate(n) buffer.Truncate(n)
stream, err := c.decryptConstructor(c.key, buffer.To(c.keyLength)) stream, err := c.decryptConstructor(c.key, buffer.To(c.saltLength))
if err != nil { if err != nil {
return M.Socksaddr{}, err return M.Socksaddr{}, err
} }
stream.XORKeyStream(buffer.From(c.keyLength), buffer.From(c.keyLength)) stream.XORKeyStream(buffer.From(c.saltLength), buffer.From(c.saltLength))
buffer.Advance(c.keyLength) buffer.Advance(c.saltLength)
return M.SocksaddrSerializer.ReadAddrPort(buffer) return M.SocksaddrSerializer.ReadAddrPort(buffer)
} }
@ -295,11 +286,11 @@ func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error)
if err != nil { if err != nil {
return return
} }
stream, err := c.decryptConstructor(c.key, p[:c.keyLength]) stream, err := c.decryptConstructor(c.key, p[:c.saltLength])
if err != nil { if err != nil {
return return
} }
buffer := buf.With(p[c.keyLength:n]) buffer := buf.As(p[c.saltLength:n])
stream.XORKeyStream(buffer.Bytes(), buffer.Bytes()) stream.XORKeyStream(buffer.Bytes(), buffer.Bytes())
destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer) destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
if err != nil { if err != nil {
@ -312,10 +303,10 @@ func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error)
func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
destination := M.SocksaddrFromNet(addr) destination := M.SocksaddrFromNet(addr)
_buffer := buf.Make(c.keyLength + M.SocksaddrSerializer.AddrPortLen(destination) + len(p)) _buffer := buf.Make(c.saltLength + M.SocksaddrSerializer.AddrPortLen(destination) + len(p))
defer common.KeepAlive(_buffer) defer common.KeepAlive(_buffer)
buffer := buf.With(common.Dup(_buffer)) buffer := buf.With(common.Dup(_buffer))
common.Must1(buffer.ReadFullFrom(rand.Reader, c.keyLength)) common.Must1(buffer.ReadFullFrom(rand.Reader, c.saltLength))
err = M.SocksaddrSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr)) err = M.SocksaddrSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr))
if err != nil { if err != nil {
return return
@ -324,11 +315,11 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if err != nil { if err != nil {
return return
} }
stream, err := c.encryptConstructor(c.key, buffer.To(c.keyLength)) stream, err := c.encryptConstructor(c.key, buffer.To(c.saltLength))
if err != nil { if err != nil {
return return
} }
stream.XORKeyStream(buffer.From(c.keyLength), buffer.From(c.keyLength)) stream.XORKeyStream(buffer.From(c.saltLength), buffer.From(c.saltLength))
_, err = c.Write(buffer.Bytes()) _, err = c.Write(buffer.Bytes())
if err != nil { if err != nil {
return return