mirror of
https://github.com/SagerNet/sing-shadowsocks.git
synced 2025-04-04 12:27:39 +03:00
Fix stream client
This commit is contained in:
parent
77a38dfcfc
commit
9fbb103c01
1 changed files with 39 additions and 48 deletions
|
@ -11,7 +11,6 @@ import (
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/sagernet/sing-shadowsocks"
|
"github.com/sagernet/sing-shadowsocks"
|
||||||
"github.com/sagernet/sing-shadowsocks/shadowaead"
|
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
"github.com/sagernet/sing/common/buf"
|
"github.com/sagernet/sing/common/buf"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
@ -77,7 +76,7 @@ func New(method string, key []byte, password string) (shadowsocks.Method, error)
|
||||||
m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBDecrypter)
|
m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBDecrypter)
|
||||||
case "rc4-md5":
|
case "rc4-md5":
|
||||||
m.keyLength = 16
|
m.keyLength = 16
|
||||||
m.saltLength = 0
|
m.saltLength = 16
|
||||||
m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
|
m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
|
||||||
h := md5.New()
|
h := md5.New()
|
||||||
h.Write(key)
|
h.Write(key)
|
||||||
|
@ -143,17 +142,17 @@ func (m *Method) KeyLength() int {
|
||||||
|
|
||||||
func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) {
|
func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) {
|
||||||
shadowsocksConn := &clientConn{
|
shadowsocksConn := &clientConn{
|
||||||
|
Method: m,
|
||||||
Conn: conn,
|
Conn: conn,
|
||||||
method: m,
|
|
||||||
destination: destination,
|
destination: destination,
|
||||||
}
|
}
|
||||||
return shadowsocksConn, shadowsocksConn.writeRequest(nil)
|
return shadowsocksConn, shadowsocksConn.writeRequest()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn {
|
func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn {
|
||||||
return &clientConn{
|
return &clientConn{
|
||||||
|
Method: m,
|
||||||
Conn: conn,
|
Conn: conn,
|
||||||
method: m,
|
|
||||||
destination: destination,
|
destination: destination,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -163,45 +162,40 @@ func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn {
|
||||||
}
|
}
|
||||||
|
|
||||||
type clientConn struct {
|
type clientConn struct {
|
||||||
|
*Method
|
||||||
net.Conn
|
net.Conn
|
||||||
|
|
||||||
method *Method
|
|
||||||
destination M.Socksaddr
|
destination M.Socksaddr
|
||||||
|
|
||||||
readStream cipher.Stream
|
readStream cipher.Stream
|
||||||
writeStream cipher.Stream
|
writeStream cipher.Stream
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientConn) writeRequest(payload []byte) error {
|
func (c *clientConn) writeRequest() error {
|
||||||
_buffer := buf.Make(c.method.keyLength + M.SocksaddrSerializer.AddrPortLen(c.destination) + len(payload))
|
_buffer := buf.StackNewSize(c.saltLength + M.SocksaddrSerializer.AddrPortLen(c.destination))
|
||||||
defer common.KeepAlive(_buffer)
|
defer common.KeepAlive(_buffer)
|
||||||
buffer := buf.With(common.Dup(_buffer))
|
buffer := common.Dup(_buffer)
|
||||||
|
defer buffer.Release()
|
||||||
|
|
||||||
salt := buffer.Extend(c.method.keyLength)
|
salt := buffer.Extend(c.saltLength)
|
||||||
common.Must1(io.ReadFull(rand.Reader, salt))
|
common.Must1(io.ReadFull(rand.Reader, salt))
|
||||||
|
|
||||||
key := shadowaead.Kdf(c.method.key, salt, c.method.keyLength)
|
stream, err := c.encryptConstructor(c.key, salt)
|
||||||
writer, err := c.method.encryptConstructor(c.method.key, salt)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
common.KeepAlive(key)
|
|
||||||
|
|
||||||
err = M.SocksaddrSerializer.WriteAddrPort(buffer, c.destination)
|
err = M.SocksaddrSerializer.WriteAddrPort(buffer, c.destination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = buffer.Write(payload)
|
|
||||||
if err != nil {
|
stream.XORKeyStream(buffer.From(c.saltLength), buffer.From(c.saltLength))
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = c.Conn.Write(buffer.Bytes())
|
_, err = c.Conn.Write(buffer.Bytes())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.writeStream = writer
|
c.writeStream = stream
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -209,29 +203,27 @@ func (c *clientConn) readResponse() error {
|
||||||
if c.readStream != nil {
|
if c.readStream != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
_salt := buf.Make(c.method.keyLength)
|
_salt := buf.Make(c.saltLength)
|
||||||
defer common.KeepAlive(_salt)
|
defer common.KeepAlive(_salt)
|
||||||
salt := common.Dup(_salt)
|
salt := common.Dup(_salt)
|
||||||
_, err := io.ReadFull(c.Conn, salt)
|
_, err := io.ReadFull(c.Conn, salt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
key := shadowaead.Kdf(c.method.key, salt, c.method.keyLength)
|
c.readStream, err = c.decryptConstructor(c.key, salt)
|
||||||
defer common.KeepAlive(key)
|
return err
|
||||||
c.readStream, err = c.method.decryptConstructor(common.Dup(key), salt)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientConn) Read(p []byte) (n int, err error) {
|
func (c *clientConn) Read(p []byte) (n int, err error) {
|
||||||
if err = c.readResponse(); err != nil {
|
if c.readStream == nil {
|
||||||
return
|
err = c.readResponse()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
n, err = c.Conn.Read(p)
|
n, err = c.Conn.Read(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return
|
||||||
}
|
}
|
||||||
c.readStream.XORKeyStream(p[:n], p[:n])
|
c.readStream.XORKeyStream(p[:n], p[:n])
|
||||||
return
|
return
|
||||||
|
@ -239,11 +231,10 @@ func (c *clientConn) Read(p []byte) (n int, err error) {
|
||||||
|
|
||||||
func (c *clientConn) Write(p []byte) (n int, err error) {
|
func (c *clientConn) Write(p []byte) (n int, err error) {
|
||||||
if c.writeStream == nil {
|
if c.writeStream == nil {
|
||||||
err = c.writeRequest(p)
|
err = c.writeRequest()
|
||||||
if err == nil {
|
if err != nil {
|
||||||
n = len(p)
|
return
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
c.writeStream.XORKeyStream(p, p)
|
c.writeStream.XORKeyStream(p, p)
|
||||||
|
@ -261,17 +252,17 @@ type clientPacketConn struct {
|
||||||
|
|
||||||
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||||
defer buffer.Release()
|
defer buffer.Release()
|
||||||
header := buf.With(buffer.ExtendHeader(c.keyLength + M.SocksaddrSerializer.AddrPortLen(destination)))
|
header := buf.With(buffer.ExtendHeader(c.saltLength + M.SocksaddrSerializer.AddrPortLen(destination)))
|
||||||
common.Must1(header.ReadFullFrom(rand.Reader, c.keyLength))
|
common.Must1(header.ReadFullFrom(rand.Reader, c.saltLength))
|
||||||
err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
|
err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
stream, err := c.encryptConstructor(c.key, buffer.To(c.keyLength))
|
stream, err := c.encryptConstructor(c.key, buffer.To(c.saltLength))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
stream.XORKeyStream(buffer.From(c.keyLength), buffer.From(c.keyLength))
|
stream.XORKeyStream(buffer.From(c.saltLength), buffer.From(c.saltLength))
|
||||||
return common.Error(c.Write(buffer.Bytes()))
|
return common.Error(c.Write(buffer.Bytes()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -281,12 +272,12 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
|
||||||
return M.Socksaddr{}, err
|
return M.Socksaddr{}, err
|
||||||
}
|
}
|
||||||
buffer.Truncate(n)
|
buffer.Truncate(n)
|
||||||
stream, err := c.decryptConstructor(c.key, buffer.To(c.keyLength))
|
stream, err := c.decryptConstructor(c.key, buffer.To(c.saltLength))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return M.Socksaddr{}, err
|
return M.Socksaddr{}, err
|
||||||
}
|
}
|
||||||
stream.XORKeyStream(buffer.From(c.keyLength), buffer.From(c.keyLength))
|
stream.XORKeyStream(buffer.From(c.saltLength), buffer.From(c.saltLength))
|
||||||
buffer.Advance(c.keyLength)
|
buffer.Advance(c.saltLength)
|
||||||
return M.SocksaddrSerializer.ReadAddrPort(buffer)
|
return M.SocksaddrSerializer.ReadAddrPort(buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -295,11 +286,11 @@ func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
stream, err := c.decryptConstructor(c.key, p[:c.keyLength])
|
stream, err := c.decryptConstructor(c.key, p[:c.saltLength])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
buffer := buf.With(p[c.keyLength:n])
|
buffer := buf.As(p[c.saltLength:n])
|
||||||
stream.XORKeyStream(buffer.Bytes(), buffer.Bytes())
|
stream.XORKeyStream(buffer.Bytes(), buffer.Bytes())
|
||||||
destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
|
destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -312,10 +303,10 @@ func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error)
|
||||||
|
|
||||||
func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||||
destination := M.SocksaddrFromNet(addr)
|
destination := M.SocksaddrFromNet(addr)
|
||||||
_buffer := buf.Make(c.keyLength + M.SocksaddrSerializer.AddrPortLen(destination) + len(p))
|
_buffer := buf.Make(c.saltLength + M.SocksaddrSerializer.AddrPortLen(destination) + len(p))
|
||||||
defer common.KeepAlive(_buffer)
|
defer common.KeepAlive(_buffer)
|
||||||
buffer := buf.With(common.Dup(_buffer))
|
buffer := buf.With(common.Dup(_buffer))
|
||||||
common.Must1(buffer.ReadFullFrom(rand.Reader, c.keyLength))
|
common.Must1(buffer.ReadFullFrom(rand.Reader, c.saltLength))
|
||||||
err = M.SocksaddrSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr))
|
err = M.SocksaddrSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
@ -324,11 +315,11 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
stream, err := c.encryptConstructor(c.key, buffer.To(c.keyLength))
|
stream, err := c.encryptConstructor(c.key, buffer.To(c.saltLength))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
stream.XORKeyStream(buffer.From(c.keyLength), buffer.From(c.keyLength))
|
stream.XORKeyStream(buffer.From(c.saltLength), buffer.From(c.saltLength))
|
||||||
_, err = c.Write(buffer.Bytes())
|
_, err = c.Write(buffer.Bytes())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue