mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-03 20:07:38 +03:00
Refine buffer
This commit is contained in:
parent
31d4b88581
commit
f16dd7a336
30 changed files with 993 additions and 209 deletions
|
@ -95,6 +95,7 @@ func HandleRequest(request *http.Request, conn net.Conn, authenticator auth.Auth
|
|||
left, right := net.Pipe()
|
||||
go func() {
|
||||
metadata.Destination = destination
|
||||
metadata.Protocol = "http"
|
||||
err = handler.NewConnection(right, metadata)
|
||||
if err != nil {
|
||||
handler.HandleError(&tcp.Error{Conn: right, Cause: err})
|
||||
|
|
|
@ -117,7 +117,8 @@ func (c *noneConn) ReadFrom(r io.Reader) (n int64, err error) {
|
|||
}
|
||||
|
||||
func (c *noneConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||
return c.Conn.(io.WriterTo).WriteTo(w)
|
||||
return io.Copy(w, c.Conn)
|
||||
// return c.Conn.(io.WriterTo).WriteTo(w)
|
||||
}
|
||||
|
||||
func (c *noneConn) RemoteAddr() net.Addr {
|
||||
|
@ -138,7 +139,7 @@ func (c *nonePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
|||
|
||||
func (c *nonePacketConn) WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error {
|
||||
defer buffer.Release()
|
||||
_header := buf.StackNew()
|
||||
_header := buf.StackNewMax()
|
||||
header := common.Dup(_header)
|
||||
err := socks.AddressSerializer.WriteAddrPort(header, addrPort)
|
||||
if err != nil {
|
||||
|
|
42
protocol/shadowsocks/service.go
Normal file
42
protocol/shadowsocks/service.go
Normal file
|
@ -0,0 +1,42 @@
|
|||
package shadowsocks
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
)
|
||||
|
||||
type Service interface {
|
||||
M.TCPConnectionHandler
|
||||
}
|
||||
|
||||
type MultiUserService interface {
|
||||
Service
|
||||
AddUser(key []byte)
|
||||
RemoveUser(key []byte)
|
||||
}
|
||||
|
||||
type Handler interface {
|
||||
M.TCPConnectionHandler
|
||||
}
|
||||
|
||||
type NoneService struct {
|
||||
handler Handler
|
||||
}
|
||||
|
||||
func NewNoneService(handler Handler) Service {
|
||||
return &NoneService{
|
||||
handler: handler,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NoneService) NewConnection(conn net.Conn, metadata M.Metadata) error {
|
||||
destination, err := socks.AddressSerializer.ReadAddrPort(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
metadata.Protocol = "shadowsocks"
|
||||
metadata.Destination = destination
|
||||
return s.handler.NewConnection(conn, metadata)
|
||||
}
|
|
@ -5,7 +5,6 @@ import (
|
|||
"encoding/binary"
|
||||
"io"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
)
|
||||
|
||||
|
@ -92,6 +91,46 @@ func (r *Reader) WriteTo(writer io.Writer) (n int64, err error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (r *Reader) readInternal() (err error) {
|
||||
start := PacketLengthBufferSize + r.cipher.Overhead()
|
||||
_, err = io.ReadFull(r.upstream, r.buffer[:start])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:start], nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
increaseNonce(r.nonce)
|
||||
length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize]))
|
||||
end := length + r.cipher.Overhead()
|
||||
_, err = io.ReadFull(r.upstream, r.buffer[:end])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
increaseNonce(r.nonce)
|
||||
r.cached = length
|
||||
r.index = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Reader) ReadByte() (byte, error) {
|
||||
if r.cached == 0 {
|
||||
err := r.readInternal()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
index := r.index
|
||||
r.index++
|
||||
r.cached--
|
||||
return r.buffer[index], nil
|
||||
}
|
||||
|
||||
func (r *Reader) Read(b []byte) (n int, err error) {
|
||||
if r.cached > 0 {
|
||||
n = copy(b, r.buffer[r.index:r.index+r.cached])
|
||||
|
@ -141,6 +180,24 @@ func (r *Reader) Read(b []byte) (n int, err error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (r *Reader) Discard(n int) error {
|
||||
for {
|
||||
if r.cached >= n {
|
||||
r.cached -= n
|
||||
r.index += n
|
||||
return nil
|
||||
} else if r.cached > 0 {
|
||||
n -= r.cached
|
||||
r.cached = 0
|
||||
r.index = 0
|
||||
}
|
||||
err := r.readInternal()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type Writer struct {
|
||||
upstream io.Writer
|
||||
cipher cipher.AEAD
|
||||
|
@ -197,10 +254,6 @@ func (w *Writer) ReadFrom(r io.Reader) (n int64, err error) {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = common.FlushVar(&w.upstream)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n += int64(readN)
|
||||
}
|
||||
}
|
||||
|
@ -227,6 +280,70 @@ func (w *Writer) Write(p []byte) (n int, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (w *Writer) Buffer() *buf.Buffer {
|
||||
return buf.With(w.buffer)
|
||||
}
|
||||
|
||||
func (w *Writer) BufferedWriter(reversed int) *BufferedWriter {
|
||||
return &BufferedWriter{
|
||||
upstream: w,
|
||||
reversed: reversed,
|
||||
data: w.buffer[PacketLengthBufferSize+w.cipher.Overhead() : len(w.buffer)-w.cipher.Overhead()],
|
||||
}
|
||||
}
|
||||
|
||||
type BufferedWriter struct {
|
||||
upstream *Writer
|
||||
data []byte
|
||||
reversed int
|
||||
index int
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) Upstream() io.Writer {
|
||||
return w.upstream
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) Replaceable() bool {
|
||||
return w.index == 0
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) Write(p []byte) (n int, err error) {
|
||||
var index int
|
||||
for {
|
||||
cachedN := copy(w.data[w.reversed+w.index:], p[index:])
|
||||
if cachedN == len(p[index:]) {
|
||||
w.index += cachedN
|
||||
return cachedN, nil
|
||||
}
|
||||
err = w.Flush()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
index += cachedN
|
||||
}
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) Flush() error {
|
||||
if w.index == 0 {
|
||||
if w.reversed > 0 {
|
||||
_, err := w.upstream.upstream.Write(w.upstream.buffer[:w.reversed])
|
||||
w.reversed = 0
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
buffer := w.upstream.buffer[w.reversed:]
|
||||
binary.BigEndian.PutUint16(buffer[:PacketLengthBufferSize], uint16(w.index))
|
||||
w.upstream.cipher.Seal(buffer[:0], w.upstream.nonce, buffer[:PacketLengthBufferSize], nil)
|
||||
increaseNonce(w.upstream.nonce)
|
||||
offset := w.upstream.cipher.Overhead() + PacketLengthBufferSize
|
||||
packet := w.upstream.cipher.Seal(buffer[offset:offset], w.upstream.nonce, buffer[offset:offset+w.index], nil)
|
||||
increaseNonce(w.upstream.nonce)
|
||||
_, err := w.upstream.upstream.Write(w.upstream.buffer[:w.reversed+offset+len(packet)])
|
||||
w.reversed = 0
|
||||
return err
|
||||
}
|
||||
|
||||
func increaseNonce(nonce []byte) {
|
||||
for i := range nonce {
|
||||
nonce[i]++
|
||||
|
|
|
@ -189,56 +189,47 @@ type clientConn struct {
|
|||
destination *M.AddrPort
|
||||
|
||||
access sync.Mutex
|
||||
reader io.Reader
|
||||
writer io.Writer
|
||||
reader *Reader
|
||||
writer *Writer
|
||||
}
|
||||
|
||||
func (c *clientConn) writeRequest(payload []byte) error {
|
||||
_request := buf.StackNew()
|
||||
request := common.Dup(_request)
|
||||
_salt := make([]byte, c.method.keySaltLength)
|
||||
salt := common.Dup(_salt)
|
||||
common.Must1(io.ReadFull(c.method.secureRNG, salt))
|
||||
|
||||
common.Must1(request.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
|
||||
|
||||
var writer io.Writer = c.Conn
|
||||
writer = &buf.BufferedWriter{
|
||||
Writer: writer,
|
||||
Buffer: request,
|
||||
}
|
||||
writer = NewWriter(
|
||||
writer,
|
||||
c.method.constructor(Kdf(c.method.key, request.Bytes(), c.method.keySaltLength)),
|
||||
key := Kdf(c.method.key, salt, c.method.keySaltLength)
|
||||
writer := NewWriter(
|
||||
c.Conn,
|
||||
c.method.constructor(common.Dup(key)),
|
||||
MaxPacketSize,
|
||||
)
|
||||
header := writer.Buffer()
|
||||
header.Write(salt)
|
||||
bufferedWriter := writer.BufferedWriter(header.Len())
|
||||
|
||||
if len(payload) > 0 {
|
||||
_header := buf.StackNew()
|
||||
header := common.Dup(_header)
|
||||
|
||||
writer = &buf.BufferedWriter{
|
||||
Writer: writer,
|
||||
Buffer: header,
|
||||
}
|
||||
|
||||
err := socks.AddressSerializer.WriteAddrPort(writer, c.destination)
|
||||
err := socks.AddressSerializer.WriteAddrPort(bufferedWriter, c.destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = writer.Write(payload)
|
||||
_, err = bufferedWriter.Write(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
err := socks.AddressSerializer.WriteAddrPort(writer, c.destination)
|
||||
err := socks.AddressSerializer.WriteAddrPort(bufferedWriter, c.destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
err := common.FlushVar(&writer)
|
||||
err := bufferedWriter.Flush()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.writer = writer
|
||||
return nil
|
||||
}
|
||||
|
@ -278,7 +269,7 @@ func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) {
|
|||
if err = c.readResponse(); err != nil {
|
||||
return
|
||||
}
|
||||
return c.reader.(io.WriterTo).WriteTo(w)
|
||||
return c.reader.WriteTo(w)
|
||||
}
|
||||
|
||||
func (c *clientConn) Write(p []byte) (n int, err error) {
|
||||
|
@ -302,9 +293,9 @@ func (c *clientConn) Write(p []byte) (n int, err error) {
|
|||
|
||||
func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
if c.writer == nil {
|
||||
panic("missing handshake")
|
||||
return rw.ReadFrom0(c, r)
|
||||
}
|
||||
return c.writer.(io.ReaderFrom).ReadFrom(r)
|
||||
return c.writer.ReadFrom(r)
|
||||
}
|
||||
|
||||
type clientPacketConn struct {
|
||||
|
|
164
protocol/shadowsocks/shadowaead/service.go
Normal file
164
protocol/shadowsocks/shadowaead/service.go
Normal file
|
@ -0,0 +1,164 @@
|
|||
package shadowaead
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/replay"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/protocol/shadowsocks"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
name string
|
||||
keySaltLength int
|
||||
constructor func(key []byte) cipher.AEAD
|
||||
key []byte
|
||||
secureRNG io.Reader
|
||||
replayFilter replay.Filter
|
||||
handler shadowsocks.Handler
|
||||
}
|
||||
|
||||
func NewService(method string, key []byte, password []byte, secureRNG io.Reader, replayFilter bool, handler shadowsocks.Handler) (shadowsocks.Service, error) {
|
||||
s := &Service{
|
||||
name: method,
|
||||
secureRNG: secureRNG,
|
||||
handler: handler,
|
||||
}
|
||||
if replayFilter {
|
||||
s.replayFilter = replay.NewBloomRing()
|
||||
}
|
||||
switch method {
|
||||
case "aes-128-gcm":
|
||||
s.keySaltLength = 16
|
||||
s.constructor = newAESGCM
|
||||
case "aes-192-gcm":
|
||||
s.keySaltLength = 24
|
||||
s.constructor = newAESGCM
|
||||
case "aes-256-gcm":
|
||||
s.keySaltLength = 32
|
||||
s.constructor = newAESGCM
|
||||
case "chacha20-ietf-poly1305":
|
||||
s.keySaltLength = 32
|
||||
s.constructor = func(key []byte) cipher.AEAD {
|
||||
cipher, err := chacha20poly1305.New(key)
|
||||
common.Must(err)
|
||||
return cipher
|
||||
}
|
||||
case "xchacha20-ietf-poly1305":
|
||||
s.keySaltLength = 32
|
||||
s.constructor = func(key []byte) cipher.AEAD {
|
||||
cipher, err := chacha20poly1305.NewX(key)
|
||||
common.Must(err)
|
||||
return cipher
|
||||
}
|
||||
}
|
||||
if len(key) == s.keySaltLength {
|
||||
s.key = key
|
||||
} else if len(key) > 0 {
|
||||
return nil, ErrBadKey
|
||||
} else if len(password) > 0 {
|
||||
s.key = shadowsocks.Key(password, s.keySaltLength)
|
||||
} else {
|
||||
return nil, ErrMissingPassword
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Service) NewConnection(conn net.Conn, metadata M.Metadata) error {
|
||||
_salt := buf.Make(s.keySaltLength)
|
||||
salt := common.Dup(_salt)
|
||||
|
||||
_, err := io.ReadFull(conn, salt)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read salt")
|
||||
}
|
||||
|
||||
key := Kdf(s.key, salt, s.keySaltLength)
|
||||
reader := NewReader(conn, s.constructor(common.Dup(key)), MaxPacketSize)
|
||||
destination, err := socks.AddressSerializer.ReadAddrPort(reader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
metadata.Protocol = "shadowsocks"
|
||||
metadata.Destination = destination
|
||||
|
||||
return s.handler.NewConnection(&serverConn{
|
||||
Service: s,
|
||||
Conn: conn,
|
||||
reader: reader,
|
||||
}, metadata)
|
||||
}
|
||||
|
||||
type serverConn struct {
|
||||
*Service
|
||||
net.Conn
|
||||
access sync.Mutex
|
||||
reader *Reader
|
||||
writer *Writer
|
||||
}
|
||||
|
||||
func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
|
||||
_salt := buf.Make(c.keySaltLength)
|
||||
salt := common.Dup(_salt)
|
||||
common.Must1(io.ReadFull(c.secureRNG, salt))
|
||||
|
||||
key := Kdf(c.key, salt, c.keySaltLength)
|
||||
writer := NewWriter(
|
||||
c.Conn,
|
||||
c.constructor(common.Dup(key)),
|
||||
MaxPacketSize,
|
||||
)
|
||||
|
||||
header := writer.Buffer()
|
||||
header.Write(salt)
|
||||
|
||||
bufferedWriter := writer.BufferedWriter(header.Len())
|
||||
if len(payload) > 0 {
|
||||
_, err = bufferedWriter.Write(payload)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = bufferedWriter.Flush()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.writer = writer
|
||||
return
|
||||
}
|
||||
|
||||
func (c *serverConn) Write(p []byte) (n int, err error) {
|
||||
if c.writer != nil {
|
||||
return c.writer.Write(p)
|
||||
}
|
||||
c.access.Lock()
|
||||
if c.writer != nil {
|
||||
c.access.Unlock()
|
||||
return c.writer.Write(p)
|
||||
}
|
||||
defer c.access.Unlock()
|
||||
return c.writeResponse(p)
|
||||
}
|
||||
|
||||
func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
if c.writer != nil {
|
||||
return rw.ReadFrom0(c, r)
|
||||
}
|
||||
return c.writer.ReadFrom(r)
|
||||
}
|
||||
|
||||
func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||
return c.reader.WriteTo(w)
|
||||
}
|
|
@ -197,8 +197,8 @@ type clientConn struct {
|
|||
|
||||
requestSalt []byte
|
||||
|
||||
reader io.Reader
|
||||
writer io.Writer
|
||||
reader *shadowaead.Reader
|
||||
writer *shadowaead.Writer
|
||||
}
|
||||
|
||||
func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) {
|
||||
|
@ -222,68 +222,56 @@ func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte)
|
|||
}
|
||||
|
||||
func (c *clientConn) writeRequest(payload []byte) error {
|
||||
_request := buf.StackNew()
|
||||
request := common.Dup(_request)
|
||||
|
||||
salt := make([]byte, KeySaltSize)
|
||||
common.Must1(io.ReadFull(c.method.secureRNG, salt))
|
||||
common.Must1(request.Write(salt))
|
||||
c.method.writeExtendedIdentityHeaders(request, salt)
|
||||
|
||||
var writer io.Writer
|
||||
writer = &buf.BufferedWriter{
|
||||
Writer: c.Conn,
|
||||
Buffer: request,
|
||||
}
|
||||
|
||||
key := Blake3DeriveKey(c.method.psk, salt, c.method.keyLength)
|
||||
writer = shadowaead.NewWriter(
|
||||
writer,
|
||||
writer := shadowaead.NewWriter(
|
||||
c.Conn,
|
||||
c.method.constructor(common.Dup(key)),
|
||||
MaxPacketSize,
|
||||
)
|
||||
|
||||
_header := buf.StackNew()
|
||||
header := common.Dup(_header)
|
||||
header := writer.Buffer()
|
||||
header.Write(salt)
|
||||
c.method.writeExtendedIdentityHeaders(header, salt)
|
||||
|
||||
writer = &buf.BufferedWriter{
|
||||
Writer: writer,
|
||||
Buffer: header,
|
||||
}
|
||||
bufferedWriter := writer.BufferedWriter(header.Len())
|
||||
|
||||
common.Must(rw.WriteByte(writer, HeaderTypeClient))
|
||||
common.Must(binary.Write(writer, binary.BigEndian, uint64(time.Now().Unix())))
|
||||
common.Must(rw.WriteByte(bufferedWriter, HeaderTypeClient))
|
||||
common.Must(binary.Write(bufferedWriter, binary.BigEndian, uint64(time.Now().Unix())))
|
||||
|
||||
err := socks.AddressSerializer.WriteAddrPort(writer, c.destination)
|
||||
err := socks.AddressSerializer.WriteAddrPort(bufferedWriter, c.destination)
|
||||
if err != nil {
|
||||
return E.Cause(err, "write destination")
|
||||
}
|
||||
|
||||
if len(payload) > 0 {
|
||||
err = binary.Write(writer, binary.BigEndian, uint16(0))
|
||||
err = binary.Write(bufferedWriter, binary.BigEndian, uint16(0))
|
||||
if err != nil {
|
||||
return E.Cause(err, "write padding length")
|
||||
}
|
||||
_, err = writer.Write(payload)
|
||||
_, err = bufferedWriter.Write(payload)
|
||||
if err != nil {
|
||||
return E.Cause(err, "write payload")
|
||||
}
|
||||
} else {
|
||||
pLen := rand.Intn(MaxPaddingLength + 1)
|
||||
err = binary.Write(writer, binary.BigEndian, uint16(pLen))
|
||||
err = binary.Write(bufferedWriter, binary.BigEndian, uint16(pLen))
|
||||
if err != nil {
|
||||
return E.Cause(err, "write padding length")
|
||||
}
|
||||
_, err = io.CopyN(writer, c.method.secureRNG, int64(pLen))
|
||||
_, err = io.CopyN(bufferedWriter, c.method.secureRNG, int64(pLen))
|
||||
if err != nil {
|
||||
return E.Cause(err, "write padding")
|
||||
}
|
||||
}
|
||||
|
||||
err = common.FlushVar(&writer)
|
||||
err = bufferedWriter.Flush()
|
||||
if err != nil {
|
||||
return E.Cause(err, "client handshake")
|
||||
}
|
||||
|
||||
c.requestSalt = salt
|
||||
c.writer = writer
|
||||
return nil
|
||||
|
@ -363,7 +351,7 @@ func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) {
|
|||
if err = c.readResponse(); err != nil {
|
||||
return
|
||||
}
|
||||
return c.reader.(io.WriterTo).WriteTo(w)
|
||||
return c.reader.WriteTo(w)
|
||||
}
|
||||
|
||||
func (c *clientConn) Write(p []byte) (n int, err error) {
|
||||
|
@ -389,10 +377,10 @@ func (c *clientConn) Write(p []byte) (n int, err error) {
|
|||
|
||||
func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
if c.writer == nil {
|
||||
panic("missing client handshake")
|
||||
return rw.ReadFrom0(c, r)
|
||||
}
|
||||
|
||||
return c.writer.(io.ReaderFrom).ReadFrom(r)
|
||||
return c.writer.ReadFrom(r)
|
||||
}
|
||||
|
||||
type clientPacketConn struct {
|
||||
|
@ -540,7 +528,7 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
|||
c.session.lastFilter = c.session.filter
|
||||
c.session.lastRemoteSeen = time.Now().Unix()
|
||||
c.session.lastRemoteCipher = c.session.remoteCipher
|
||||
c.session.filter = new(wgReplay.Filter)
|
||||
c.session.filter = wgReplay.Filter{}
|
||||
}
|
||||
}
|
||||
c.session.remoteSessionId = sessionId
|
||||
|
@ -577,8 +565,8 @@ type udpSession struct {
|
|||
cipher cipher.AEAD
|
||||
remoteCipher cipher.AEAD
|
||||
lastRemoteCipher cipher.AEAD
|
||||
filter *wgReplay.Filter
|
||||
lastFilter *wgReplay.Filter
|
||||
filter wgReplay.Filter
|
||||
lastFilter wgReplay.Filter
|
||||
}
|
||||
|
||||
func (s *udpSession) nextPacketId() uint64 {
|
||||
|
@ -588,7 +576,6 @@ func (s *udpSession) nextPacketId() uint64 {
|
|||
func (m *Method) newUDPSession() *udpSession {
|
||||
session := &udpSession{
|
||||
sessionId: rand.Uint64(),
|
||||
filter: new(wgReplay.Filter),
|
||||
}
|
||||
if m.udpCipher == nil {
|
||||
sessionId := make([]byte, 8)
|
||||
|
|
195
protocol/shadowsocks/shadowaead_2022/service.go
Normal file
195
protocol/shadowsocks/shadowaead_2022/service.go
Normal file
|
@ -0,0 +1,195 @@
|
|||
package shadowaead_2022
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/replay"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/protocol/shadowsocks"
|
||||
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
name string
|
||||
secureRNG io.Reader
|
||||
keyLength int
|
||||
constructor func(key []byte) cipher.AEAD
|
||||
psk []byte
|
||||
replayFilter replay.Filter
|
||||
handler shadowsocks.Handler
|
||||
}
|
||||
|
||||
func NewService(method string, psk []byte, secureRNG io.Reader, handler shadowsocks.Handler) (shadowsocks.Service, error) {
|
||||
s := &Service{
|
||||
name: method,
|
||||
psk: psk,
|
||||
secureRNG: secureRNG,
|
||||
replayFilter: replay.NewCuckoo(60),
|
||||
handler: handler,
|
||||
}
|
||||
|
||||
if len(psk) != KeySaltSize {
|
||||
return nil, shadowaead.ErrBadKey
|
||||
}
|
||||
|
||||
switch method {
|
||||
case "2022-blake3-aes-128-gcm":
|
||||
s.keyLength = 16
|
||||
s.constructor = newAESGCM
|
||||
// m.blockConstructor = newAES
|
||||
// m.udpBlockCipher = newAES(m.psk)
|
||||
case "2022-blake3-aes-256-gcm":
|
||||
s.keyLength = 32
|
||||
s.constructor = newAESGCM
|
||||
// m.blockConstructor = newAES
|
||||
// m.udpBlockCipher = newAES(m.psk)
|
||||
case "2022-blake3-chacha20-poly1305":
|
||||
s.keyLength = 32
|
||||
s.constructor = newChacha20Poly1305
|
||||
// m.udpCipher = newXChacha20Poly1305(m.psk)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Service) NewConnection(conn net.Conn, metadata M.Metadata) error {
|
||||
requestSalt := make([]byte, KeySaltSize)
|
||||
_, err := io.ReadFull(conn, requestSalt)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read request salt")
|
||||
}
|
||||
|
||||
if !s.replayFilter.Check(requestSalt) {
|
||||
return E.New("salt not unique")
|
||||
}
|
||||
|
||||
requestKey := Blake3DeriveKey(s.psk, requestSalt, s.keyLength)
|
||||
reader := shadowaead.NewReader(
|
||||
conn,
|
||||
s.constructor(common.Dup(requestKey)),
|
||||
MaxPacketSize,
|
||||
)
|
||||
|
||||
headerType, err := rw.ReadByte(reader)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read header")
|
||||
}
|
||||
|
||||
if headerType != HeaderTypeClient {
|
||||
return ErrBadHeaderType
|
||||
}
|
||||
|
||||
var epoch uint64
|
||||
err = binary.Read(reader, binary.BigEndian, &epoch)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read timestamp")
|
||||
}
|
||||
if math.Abs(float64(time.Now().Unix()-int64(epoch))) > 30 {
|
||||
return ErrBadTimestamp
|
||||
}
|
||||
|
||||
destination, err := socks.AddressSerializer.ReadAddrPort(reader)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read destination")
|
||||
}
|
||||
|
||||
var paddingLen uint16
|
||||
err = binary.Read(reader, binary.BigEndian, &paddingLen)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read padding length")
|
||||
}
|
||||
|
||||
if paddingLen > 0 {
|
||||
err = reader.Discard(int(paddingLen))
|
||||
if err != nil {
|
||||
return E.Cause(err, "discard padding")
|
||||
}
|
||||
}
|
||||
|
||||
metadata.Protocol = "shadowsocks"
|
||||
metadata.Destination = destination
|
||||
return s.handler.NewConnection(&serverConn{
|
||||
Service: s,
|
||||
Conn: conn,
|
||||
reader: reader,
|
||||
requestSalt: requestSalt,
|
||||
}, metadata)
|
||||
}
|
||||
|
||||
type serverConn struct {
|
||||
*Service
|
||||
net.Conn
|
||||
access sync.Mutex
|
||||
reader *shadowaead.Reader
|
||||
writer *shadowaead.Writer
|
||||
requestSalt []byte
|
||||
}
|
||||
|
||||
func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
|
||||
_salt := make([]byte, KeySaltSize)
|
||||
salt := common.Dup(_salt)
|
||||
common.Must1(io.ReadFull(c.secureRNG, salt))
|
||||
key := Blake3DeriveKey(c.psk, salt, c.keyLength)
|
||||
writer := shadowaead.NewWriter(
|
||||
c.Conn,
|
||||
c.constructor(common.Dup(key)),
|
||||
MaxPacketSize,
|
||||
)
|
||||
header := writer.Buffer()
|
||||
header.Write(salt)
|
||||
bufferedWriter := writer.BufferedWriter(header.Len())
|
||||
|
||||
common.Must(rw.WriteByte(bufferedWriter, HeaderTypeServer))
|
||||
common.Must(binary.Write(bufferedWriter, binary.BigEndian, uint64(time.Now().Unix())))
|
||||
common.Must1(bufferedWriter.Write(c.requestSalt))
|
||||
c.requestSalt = nil
|
||||
|
||||
if len(payload) > 0 {
|
||||
_, err = bufferedWriter.Write(payload)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = bufferedWriter.Flush()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.writer = writer
|
||||
n = len(payload)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *serverConn) Write(p []byte) (n int, err error) {
|
||||
if c.writer != nil {
|
||||
return c.writer.Write(p)
|
||||
}
|
||||
c.access.Lock()
|
||||
if c.writer != nil {
|
||||
c.access.Unlock()
|
||||
return c.writer.Write(p)
|
||||
}
|
||||
defer c.access.Unlock()
|
||||
return c.writeResponse(p)
|
||||
}
|
||||
|
||||
func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
if c.writer != nil {
|
||||
return rw.ReadFrom0(c, r)
|
||||
}
|
||||
return c.writer.ReadFrom(r)
|
||||
}
|
||||
|
||||
func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||
return c.reader.WriteTo(w)
|
||||
}
|
|
@ -2,13 +2,13 @@ package socks
|
|||
|
||||
import (
|
||||
"context"
|
||||
"github.com/sagernet/sing/common/task"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/task"
|
||||
)
|
||||
|
||||
type PacketConn interface {
|
||||
|
@ -47,26 +47,32 @@ func (s *PacketConnStub) SetWriteDeadline(t time.Time) error {
|
|||
|
||||
func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error {
|
||||
return task.Run(ctx, func() error {
|
||||
_buffer := buf.StackNew()
|
||||
_buffer := buf.StackNewMax()
|
||||
buffer := common.Dup(_buffer)
|
||||
data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader)
|
||||
for {
|
||||
destination, err := conn.ReadPacket(buffer)
|
||||
data.FullReset()
|
||||
destination, err := conn.ReadPacket(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buffer.Truncate(data.Len())
|
||||
err = dest.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}, func() error {
|
||||
_buffer := buf.StackNew()
|
||||
_buffer := buf.StackNewMax()
|
||||
buffer := common.Dup(_buffer)
|
||||
data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader)
|
||||
for {
|
||||
destination, err := dest.ReadPacket(buffer)
|
||||
data.FullReset()
|
||||
destination, err := dest.ReadPacket(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buffer.Truncate(data.Len())
|
||||
err = conn.WritePacket(buffer, destination)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -125,7 +131,8 @@ func (c *associatePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error
|
|||
|
||||
func (c *associatePacketConn) WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error {
|
||||
defer buffer.Release()
|
||||
header := buf.New()
|
||||
_header := buf.StackNew()
|
||||
header := common.Dup(_header)
|
||||
common.Must(header.WriteZeroN(3))
|
||||
common.Must(AddressSerializer.WriteAddrPort(header, addrPort))
|
||||
buffer = buffer.WriteBufferAtFirst(header)
|
||||
|
|
|
@ -83,7 +83,7 @@ func HandleConnection(conn net.Conn, authenticator auth.Authenticator, bind neti
|
|||
if err != nil {
|
||||
return E.Cause(err, "read user auth request")
|
||||
}
|
||||
response := new(UsernamePasswordAuthResponse)
|
||||
response := &UsernamePasswordAuthResponse{}
|
||||
if authenticator.Verify(usernamePasswordAuthRequest.Username, usernamePasswordAuthRequest.Password) {
|
||||
response.Status = UsernamePasswordStatusSuccess
|
||||
} else {
|
||||
|
@ -109,6 +109,7 @@ func HandleConnection(conn net.Conn, authenticator auth.Authenticator, bind neti
|
|||
if err != nil {
|
||||
return E.Cause(err, "write socks response")
|
||||
}
|
||||
metadata.Protocol = "socks"
|
||||
metadata.Destination = request.Destination
|
||||
return handler.NewConnection(conn, metadata)
|
||||
case CommandUDPAssociate:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue