mirror of
https://github.com/SagerNet/sing-shadowsocks.git
synced 2025-04-03 20:07:40 +03:00
Fix stream client
This commit is contained in:
parent
77a38dfcfc
commit
9fbb103c01
1 changed files with 39 additions and 48 deletions
|
@ -11,7 +11,6 @@ import (
|
|||
"os"
|
||||
|
||||
"github.com/sagernet/sing-shadowsocks"
|
||||
"github.com/sagernet/sing-shadowsocks/shadowaead"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
@ -77,7 +76,7 @@ func New(method string, key []byte, password string) (shadowsocks.Method, error)
|
|||
m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBDecrypter)
|
||||
case "rc4-md5":
|
||||
m.keyLength = 16
|
||||
m.saltLength = 0
|
||||
m.saltLength = 16
|
||||
m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
|
||||
h := md5.New()
|
||||
h.Write(key)
|
||||
|
@ -143,17 +142,17 @@ func (m *Method) KeyLength() int {
|
|||
|
||||
func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) {
|
||||
shadowsocksConn := &clientConn{
|
||||
Method: m,
|
||||
Conn: conn,
|
||||
method: m,
|
||||
destination: destination,
|
||||
}
|
||||
return shadowsocksConn, shadowsocksConn.writeRequest(nil)
|
||||
return shadowsocksConn, shadowsocksConn.writeRequest()
|
||||
}
|
||||
|
||||
func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn {
|
||||
return &clientConn{
|
||||
Method: m,
|
||||
Conn: conn,
|
||||
method: m,
|
||||
destination: destination,
|
||||
}
|
||||
}
|
||||
|
@ -163,45 +162,40 @@ func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn {
|
|||
}
|
||||
|
||||
type clientConn struct {
|
||||
*Method
|
||||
net.Conn
|
||||
|
||||
method *Method
|
||||
destination M.Socksaddr
|
||||
|
||||
readStream cipher.Stream
|
||||
writeStream cipher.Stream
|
||||
}
|
||||
|
||||
func (c *clientConn) writeRequest(payload []byte) error {
|
||||
_buffer := buf.Make(c.method.keyLength + M.SocksaddrSerializer.AddrPortLen(c.destination) + len(payload))
|
||||
func (c *clientConn) writeRequest() error {
|
||||
_buffer := buf.StackNewSize(c.saltLength + M.SocksaddrSerializer.AddrPortLen(c.destination))
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := buf.With(common.Dup(_buffer))
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
|
||||
salt := buffer.Extend(c.method.keyLength)
|
||||
salt := buffer.Extend(c.saltLength)
|
||||
common.Must1(io.ReadFull(rand.Reader, salt))
|
||||
|
||||
key := shadowaead.Kdf(c.method.key, salt, c.method.keyLength)
|
||||
writer, err := c.method.encryptConstructor(c.method.key, salt)
|
||||
stream, err := c.encryptConstructor(c.key, salt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
common.KeepAlive(key)
|
||||
|
||||
err = M.SocksaddrSerializer.WriteAddrPort(buffer, c.destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = buffer.Write(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stream.XORKeyStream(buffer.From(c.saltLength), buffer.From(c.saltLength))
|
||||
|
||||
_, err = c.Conn.Write(buffer.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.writeStream = writer
|
||||
c.writeStream = stream
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -209,29 +203,27 @@ func (c *clientConn) readResponse() error {
|
|||
if c.readStream != nil {
|
||||
return nil
|
||||
}
|
||||
_salt := buf.Make(c.method.keyLength)
|
||||
_salt := buf.Make(c.saltLength)
|
||||
defer common.KeepAlive(_salt)
|
||||
salt := common.Dup(_salt)
|
||||
_, err := io.ReadFull(c.Conn, salt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
key := shadowaead.Kdf(c.method.key, salt, c.method.keyLength)
|
||||
defer common.KeepAlive(key)
|
||||
c.readStream, err = c.method.decryptConstructor(common.Dup(key), salt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
c.readStream, err = c.decryptConstructor(c.key, salt)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *clientConn) Read(p []byte) (n int, err error) {
|
||||
if err = c.readResponse(); err != nil {
|
||||
return
|
||||
if c.readStream == nil {
|
||||
err = c.readResponse()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
n, err = c.Conn.Read(p)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return
|
||||
}
|
||||
c.readStream.XORKeyStream(p[:n], p[:n])
|
||||
return
|
||||
|
@ -239,11 +231,10 @@ func (c *clientConn) Read(p []byte) (n int, err error) {
|
|||
|
||||
func (c *clientConn) Write(p []byte) (n int, err error) {
|
||||
if c.writeStream == nil {
|
||||
err = c.writeRequest(p)
|
||||
if err == nil {
|
||||
n = len(p)
|
||||
err = c.writeRequest()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
c.writeStream.XORKeyStream(p, p)
|
||||
|
@ -261,17 +252,17 @@ type clientPacketConn struct {
|
|||
|
||||
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
defer buffer.Release()
|
||||
header := buf.With(buffer.ExtendHeader(c.keyLength + M.SocksaddrSerializer.AddrPortLen(destination)))
|
||||
common.Must1(header.ReadFullFrom(rand.Reader, c.keyLength))
|
||||
header := buf.With(buffer.ExtendHeader(c.saltLength + M.SocksaddrSerializer.AddrPortLen(destination)))
|
||||
common.Must1(header.ReadFullFrom(rand.Reader, c.saltLength))
|
||||
err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stream, err := c.encryptConstructor(c.key, buffer.To(c.keyLength))
|
||||
stream, err := c.encryptConstructor(c.key, buffer.To(c.saltLength))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stream.XORKeyStream(buffer.From(c.keyLength), buffer.From(c.keyLength))
|
||||
stream.XORKeyStream(buffer.From(c.saltLength), buffer.From(c.saltLength))
|
||||
return common.Error(c.Write(buffer.Bytes()))
|
||||
}
|
||||
|
||||
|
@ -281,12 +272,12 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
|
|||
return M.Socksaddr{}, err
|
||||
}
|
||||
buffer.Truncate(n)
|
||||
stream, err := c.decryptConstructor(c.key, buffer.To(c.keyLength))
|
||||
stream, err := c.decryptConstructor(c.key, buffer.To(c.saltLength))
|
||||
if err != nil {
|
||||
return M.Socksaddr{}, err
|
||||
}
|
||||
stream.XORKeyStream(buffer.From(c.keyLength), buffer.From(c.keyLength))
|
||||
buffer.Advance(c.keyLength)
|
||||
stream.XORKeyStream(buffer.From(c.saltLength), buffer.From(c.saltLength))
|
||||
buffer.Advance(c.saltLength)
|
||||
return M.SocksaddrSerializer.ReadAddrPort(buffer)
|
||||
}
|
||||
|
||||
|
@ -295,11 +286,11 @@ func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error)
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
stream, err := c.decryptConstructor(c.key, p[:c.keyLength])
|
||||
stream, err := c.decryptConstructor(c.key, p[:c.saltLength])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
buffer := buf.With(p[c.keyLength:n])
|
||||
buffer := buf.As(p[c.saltLength:n])
|
||||
stream.XORKeyStream(buffer.Bytes(), buffer.Bytes())
|
||||
destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
|
||||
if err != nil {
|
||||
|
@ -312,10 +303,10 @@ func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error)
|
|||
|
||||
func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
destination := M.SocksaddrFromNet(addr)
|
||||
_buffer := buf.Make(c.keyLength + M.SocksaddrSerializer.AddrPortLen(destination) + len(p))
|
||||
_buffer := buf.Make(c.saltLength + M.SocksaddrSerializer.AddrPortLen(destination) + len(p))
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := buf.With(common.Dup(_buffer))
|
||||
common.Must1(buffer.ReadFullFrom(rand.Reader, c.keyLength))
|
||||
common.Must1(buffer.ReadFullFrom(rand.Reader, c.saltLength))
|
||||
err = M.SocksaddrSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr))
|
||||
if err != nil {
|
||||
return
|
||||
|
@ -324,11 +315,11 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
stream, err := c.encryptConstructor(c.key, buffer.To(c.keyLength))
|
||||
stream, err := c.encryptConstructor(c.key, buffer.To(c.saltLength))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
stream.XORKeyStream(buffer.From(c.keyLength), buffer.From(c.keyLength))
|
||||
stream.XORKeyStream(buffer.From(c.saltLength), buffer.From(c.saltLength))
|
||||
_, err = c.Write(buffer.Bytes())
|
||||
if err != nil {
|
||||
return
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue