Fix stream client

This commit is contained in:
世界 2022-05-30 11:26:25 +08:00
parent 77a38dfcfc
commit 9fbb103c01
No known key found for this signature in database
GPG key ID: CD109927C34A63C4

View file

@ -11,7 +11,6 @@ import (
"os"
"github.com/sagernet/sing-shadowsocks"
"github.com/sagernet/sing-shadowsocks/shadowaead"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
@ -77,7 +76,7 @@ func New(method string, key []byte, password string) (shadowsocks.Method, error)
m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBDecrypter)
case "rc4-md5":
m.keyLength = 16
m.saltLength = 0
m.saltLength = 16
m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
h := md5.New()
h.Write(key)
@ -143,17 +142,17 @@ func (m *Method) KeyLength() int {
func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) {
shadowsocksConn := &clientConn{
Method: m,
Conn: conn,
method: m,
destination: destination,
}
return shadowsocksConn, shadowsocksConn.writeRequest(nil)
return shadowsocksConn, shadowsocksConn.writeRequest()
}
func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn {
return &clientConn{
Method: m,
Conn: conn,
method: m,
destination: destination,
}
}
@ -163,45 +162,40 @@ func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn {
}
type clientConn struct {
*Method
net.Conn
method *Method
destination M.Socksaddr
readStream cipher.Stream
writeStream cipher.Stream
}
func (c *clientConn) writeRequest(payload []byte) error {
_buffer := buf.Make(c.method.keyLength + M.SocksaddrSerializer.AddrPortLen(c.destination) + len(payload))
func (c *clientConn) writeRequest() error {
_buffer := buf.StackNewSize(c.saltLength + M.SocksaddrSerializer.AddrPortLen(c.destination))
defer common.KeepAlive(_buffer)
buffer := buf.With(common.Dup(_buffer))
buffer := common.Dup(_buffer)
defer buffer.Release()
salt := buffer.Extend(c.method.keyLength)
salt := buffer.Extend(c.saltLength)
common.Must1(io.ReadFull(rand.Reader, salt))
key := shadowaead.Kdf(c.method.key, salt, c.method.keyLength)
writer, err := c.method.encryptConstructor(c.method.key, salt)
stream, err := c.encryptConstructor(c.key, salt)
if err != nil {
return err
}
common.KeepAlive(key)
err = M.SocksaddrSerializer.WriteAddrPort(buffer, c.destination)
if err != nil {
return err
}
_, err = buffer.Write(payload)
if err != nil {
return err
}
stream.XORKeyStream(buffer.From(c.saltLength), buffer.From(c.saltLength))
_, err = c.Conn.Write(buffer.Bytes())
if err != nil {
return err
}
c.writeStream = writer
c.writeStream = stream
return nil
}
@ -209,29 +203,27 @@ func (c *clientConn) readResponse() error {
if c.readStream != nil {
return nil
}
_salt := buf.Make(c.method.keyLength)
_salt := buf.Make(c.saltLength)
defer common.KeepAlive(_salt)
salt := common.Dup(_salt)
_, err := io.ReadFull(c.Conn, salt)
if err != nil {
return err
}
key := shadowaead.Kdf(c.method.key, salt, c.method.keyLength)
defer common.KeepAlive(key)
c.readStream, err = c.method.decryptConstructor(common.Dup(key), salt)
if err != nil {
return err
}
return nil
c.readStream, err = c.decryptConstructor(c.key, salt)
return err
}
func (c *clientConn) Read(p []byte) (n int, err error) {
if err = c.readResponse(); err != nil {
return
if c.readStream == nil {
err = c.readResponse()
if err != nil {
return
}
}
n, err = c.Conn.Read(p)
if err != nil {
return 0, err
return
}
c.readStream.XORKeyStream(p[:n], p[:n])
return
@ -239,11 +231,10 @@ func (c *clientConn) Read(p []byte) (n int, err error) {
func (c *clientConn) Write(p []byte) (n int, err error) {
if c.writeStream == nil {
err = c.writeRequest(p)
if err == nil {
n = len(p)
err = c.writeRequest()
if err != nil {
return
}
return
}
c.writeStream.XORKeyStream(p, p)
@ -261,17 +252,17 @@ type clientPacketConn struct {
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
defer buffer.Release()
header := buf.With(buffer.ExtendHeader(c.keyLength + M.SocksaddrSerializer.AddrPortLen(destination)))
common.Must1(header.ReadFullFrom(rand.Reader, c.keyLength))
header := buf.With(buffer.ExtendHeader(c.saltLength + M.SocksaddrSerializer.AddrPortLen(destination)))
common.Must1(header.ReadFullFrom(rand.Reader, c.saltLength))
err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
if err != nil {
return err
}
stream, err := c.encryptConstructor(c.key, buffer.To(c.keyLength))
stream, err := c.encryptConstructor(c.key, buffer.To(c.saltLength))
if err != nil {
return err
}
stream.XORKeyStream(buffer.From(c.keyLength), buffer.From(c.keyLength))
stream.XORKeyStream(buffer.From(c.saltLength), buffer.From(c.saltLength))
return common.Error(c.Write(buffer.Bytes()))
}
@ -281,12 +272,12 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
return M.Socksaddr{}, err
}
buffer.Truncate(n)
stream, err := c.decryptConstructor(c.key, buffer.To(c.keyLength))
stream, err := c.decryptConstructor(c.key, buffer.To(c.saltLength))
if err != nil {
return M.Socksaddr{}, err
}
stream.XORKeyStream(buffer.From(c.keyLength), buffer.From(c.keyLength))
buffer.Advance(c.keyLength)
stream.XORKeyStream(buffer.From(c.saltLength), buffer.From(c.saltLength))
buffer.Advance(c.saltLength)
return M.SocksaddrSerializer.ReadAddrPort(buffer)
}
@ -295,11 +286,11 @@ func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error)
if err != nil {
return
}
stream, err := c.decryptConstructor(c.key, p[:c.keyLength])
stream, err := c.decryptConstructor(c.key, p[:c.saltLength])
if err != nil {
return
}
buffer := buf.With(p[c.keyLength:n])
buffer := buf.As(p[c.saltLength:n])
stream.XORKeyStream(buffer.Bytes(), buffer.Bytes())
destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
if err != nil {
@ -312,10 +303,10 @@ func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error)
func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
destination := M.SocksaddrFromNet(addr)
_buffer := buf.Make(c.keyLength + M.SocksaddrSerializer.AddrPortLen(destination) + len(p))
_buffer := buf.Make(c.saltLength + M.SocksaddrSerializer.AddrPortLen(destination) + len(p))
defer common.KeepAlive(_buffer)
buffer := buf.With(common.Dup(_buffer))
common.Must1(buffer.ReadFullFrom(rand.Reader, c.keyLength))
common.Must1(buffer.ReadFullFrom(rand.Reader, c.saltLength))
err = M.SocksaddrSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr))
if err != nil {
return
@ -324,11 +315,11 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if err != nil {
return
}
stream, err := c.encryptConstructor(c.key, buffer.To(c.keyLength))
stream, err := c.encryptConstructor(c.key, buffer.To(c.saltLength))
if err != nil {
return
}
stream.XORKeyStream(buffer.From(c.keyLength), buffer.From(c.keyLength))
stream.XORKeyStream(buffer.From(c.saltLength), buffer.From(c.saltLength))
_, err = c.Write(buffer.Bytes())
if err != nil {
return