Add client and service

This commit is contained in:
世界 2023-02-20 12:52:48 +08:00
parent 320d58c57a
commit 6c9bdfc858
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
16 changed files with 1173 additions and 0 deletions

104
client.go Normal file
View file

@ -0,0 +1,104 @@
package shadowtls
import (
"context"
"crypto/hmac"
"crypto/sha1"
"encoding/hex"
"net"
"os"
"github.com/sagernet/sing/common/debug"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type ClientConfig struct {
Version int
Password string
Server M.Socksaddr
Dialer N.Dialer
TLSHandshake TLSHandshakeFunc
Logger logger.ContextLogger
}
type Client struct {
version int
password string
server M.Socksaddr
dialer N.Dialer
tlsHandshake TLSHandshakeFunc
logger logger.ContextLogger
}
func NewClient(config ClientConfig) (*Client, error) {
client := &Client{
version: config.Version,
password: config.Password,
server: config.Server,
dialer: config.Dialer,
tlsHandshake: config.TLSHandshake,
logger: config.Logger,
}
if !client.server.IsValid() || client.dialer == nil || client.tlsHandshake == nil {
return nil, os.ErrInvalid
}
switch client.version {
case 1, 2, 3:
default:
return nil, E.New("unknown protocol version: ", client.version)
}
if client.dialer == nil {
client.dialer = N.SystemDialer
}
return client, nil
}
func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.server)
if err != nil {
return nil, err
}
switch c.version {
default:
fallthrough
case 1:
err = c.tlsHandshake(ctx, conn, nil)
if err != nil {
return nil, err
}
c.logger.TraceContext(ctx, "clint handshake finished")
return conn, nil
case 2:
hashConn := newHashReadConn(conn, c.password)
err = c.tlsHandshake(ctx, hashConn, nil)
if err != nil {
return nil, err
}
c.logger.TraceContext(ctx, "clint handshake finished")
return newClientConn(hashConn), nil
case 3:
stream := newStreamWrapper(conn, c.password)
err = c.tlsHandshake(ctx, stream, generateSessionID(c.password))
if err != nil {
return nil, err
}
c.logger.TraceContext(ctx, "handshake success")
authorized, serverRandom, readHMAC := stream.Authorized()
if !authorized {
return nil, E.New("traffic hijacked or TLS1.3 is not supported")
}
if debug.Enabled {
c.logger.TraceContext(ctx, "authorized, server random extracted: ", hex.EncodeToString(serverRandom))
}
hmacAdd := hmac.New(sha1.New, []byte(c.password))
hmacAdd.Write(serverRandom)
hmacAdd.Write([]byte("C"))
hmacVerify := hmac.New(sha1.New, []byte(c.password))
hmacVerify.Write(serverRandom)
hmacVerify.Write([]byte("S"))
return newVerifiedConn(conn, hmacAdd, hmacVerify, readHMAC), nil
}
}

1
go.mod
View file

@ -3,6 +3,7 @@ module github.com/sagernet/sing-shadowtls
go 1.18
require (
github.com/sagernet/sing v0.1.6
golang.org/x/crypto v0.6.0
golang.org/x/sys v0.5.0
)

2
go.sum
View file

@ -1,3 +1,5 @@
github.com/sagernet/sing v0.1.6 h1:Qy63OUfKpcqKjfd5rPmUlj0RGjHZSK/PJn0duyCCsRg=
github.com/sagernet/sing v0.1.6/go.mod h1:JLSXsPTGRJFo/3X7EcAOCUgJH2/gAoxSJgBsnCZRp/w=
golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc=
golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58=
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=

196
service.go Normal file
View file

@ -0,0 +1,196 @@
package shadowtls
import (
"context"
"crypto/hmac"
"crypto/sha1"
"encoding/hex"
"net"
"os"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/debug"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/task"
)
type ServiceConfig struct {
Version int
Password string
HandshakeServer M.Socksaddr
HandshakeDialer N.Dialer
Handler Handler
Logger logger.ContextLogger
}
type Handler interface {
N.TCPConnectionHandler
E.Handler
}
type Service struct {
version int
password string
handshakeServer M.Socksaddr
handshakeDialer N.Dialer
handler Handler
logger logger.ContextLogger
}
func NewService(config ServiceConfig) (*Service, error) {
service := &Service{
version: config.Version,
password: config.Password,
handshakeServer: config.HandshakeServer,
handshakeDialer: config.HandshakeDialer,
handler: config.Handler,
logger: config.Logger,
}
if !service.handshakeServer.IsValid() || service.handler == nil || service.logger == nil {
return nil, os.ErrInvalid
}
switch config.Version {
case 1, 2, 3:
default:
return nil, E.New("unknown protocol version: ", config.Version)
}
return service, nil
}
func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
handshakeConn, err := s.handshakeDialer.DialContext(ctx, N.NetworkTCP, s.handshakeServer)
if err != nil {
return E.Cause(err, "server handshake")
}
switch s.version {
default:
fallthrough
case 1:
var group task.Group
group.Append("client handshake", func(ctx context.Context) error {
return copyUntilHandshakeFinished(handshakeConn, conn)
})
group.Append("server handshake", func(ctx context.Context) error {
return copyUntilHandshakeFinished(conn, handshakeConn)
})
group.FastFail()
group.Cleanup(func() {
handshakeConn.Close()
})
err = group.Run(ctx)
if err != nil {
return err
}
s.logger.TraceContext(ctx, "handshake finished")
return s.handler.NewConnection(ctx, conn, metadata)
case 2:
hashConn := newHashWriteConn(conn, s.password)
go bufio.Copy(hashConn, handshakeConn)
var request *buf.Buffer
request, err = copyUntilHandshakeFinishedV2(ctx, s.logger, handshakeConn, conn, hashConn, 2)
if err == nil {
s.logger.TraceContext(ctx, "handshake finished")
handshakeConn.Close()
return s.handler.NewConnection(ctx, bufio.NewCachedConn(newConn(conn), request), metadata)
} else if err == os.ErrPermission {
s.logger.WarnContext(ctx, "fallback connection")
hashConn.Fallback()
return common.Error(bufio.Copy(handshakeConn, conn))
} else {
return err
}
case 3:
var clientHelloFrame *buf.Buffer
clientHelloFrame, err = extractFrame(conn)
if err != nil {
return E.Cause(err, "read client handshake")
}
_, err = handshakeConn.Write(clientHelloFrame.Bytes())
if err != nil {
clientHelloFrame.Release()
return E.Cause(err, "write client handshake")
}
err = verifyClientHello(clientHelloFrame.Bytes(), s.password)
if err != nil {
s.logger.WarnContext(ctx, E.Cause(err, "client hello verify failed"))
return bufio.CopyConn(ctx, conn, handshakeConn)
}
s.logger.TraceContext(ctx, "client hello verify success")
clientHelloFrame.Release()
var serverHelloFrame *buf.Buffer
serverHelloFrame, err = extractFrame(handshakeConn)
if err != nil {
return E.Cause(err, "read server handshake")
}
_, err = conn.Write(serverHelloFrame.Bytes())
if err != nil {
serverHelloFrame.Release()
return E.Cause(err, "write server handshake")
}
serverRandom := extractServerRandom(serverHelloFrame.Bytes())
if serverRandom == nil {
s.logger.WarnContext(ctx, "server random extract failed, will copy bidirectional")
return bufio.CopyConn(ctx, conn, handshakeConn)
}
if !isServerHelloSupportTLS13(serverHelloFrame.Bytes()) {
s.logger.WarnContext(ctx, "TLS 1.3 is not supported, will copy bidirectional")
return bufio.CopyConn(ctx, conn, handshakeConn)
}
serverHelloFrame.Release()
if debug.Enabled {
s.logger.TraceContext(ctx, "client authenticated. server random extracted: ", hex.EncodeToString(serverRandom))
}
hmacWrite := hmac.New(sha1.New, []byte(s.password))
hmacWrite.Write(serverRandom)
hmacAdd := hmac.New(sha1.New, []byte(s.password))
hmacAdd.Write(serverRandom)
hmacAdd.Write([]byte("S"))
hmacVerify := hmac.New(sha1.New, []byte(s.password))
hmacVerifyReset := func() {
hmacVerify.Reset()
hmacVerify.Write(serverRandom)
hmacVerify.Write([]byte("C"))
}
var clientFirstFrame *buf.Buffer
var group task.Group
var handshakeFinished bool
group.Append("client handshake relay", func(ctx context.Context) error {
clientFrame, cErr := copyByFrameUntilHMACMatches(conn, handshakeConn, hmacVerify, hmacVerifyReset)
if cErr == nil {
clientFirstFrame = clientFrame
handshakeFinished = true
handshakeConn.Close()
}
return cErr
})
group.Append("server handshake relay", func(ctx context.Context) error {
cErr := copyByFrameWithModification(handshakeConn, conn, s.password, serverRandom, hmacWrite)
if E.IsClosedOrCanceled(cErr) && handshakeFinished {
return nil
}
return cErr
})
group.Cleanup(func() {
handshakeConn.Close()
})
err = group.Run(ctx)
if err != nil {
return E.Cause(err, "handshake relay")
}
s.logger.TraceContext(ctx, "handshake relay finished")
return s.handler.NewConnection(ctx, bufio.NewCachedConn(newVerifiedConn(conn, hmacAdd, hmacVerify, nil), clientFirstFrame), metadata)
}
}

20
tls.go Normal file
View file

@ -0,0 +1,20 @@
//go:build go1.20
package shadowtls
import (
sTLS "github.com/sagernet/sing-shadowtls/tls"
)
type (
sTLSConfig = sTLS.Config
sTLSConnectionState = sTLS.ConnectionState
sTLSConn = sTLS.Conn
sTLSCurveID = sTLS.CurveID
sTLSRenegotiationSupport = sTLS.RenegotiationSupport
)
var (
sTLSCipherSuites = sTLS.CipherSuites
sTLSClient = sTLS.Client
)

18
tls_compact.go Normal file
View file

@ -0,0 +1,18 @@
//go:build !go1.20
package shadowtls
import sTLS "github.com/sagernet/sing-shadowtls/tls_compact"
type (
sTLSConfig = sTLS.Config
sTLSConnectionState = sTLS.ConnectionState
sTLSConn = sTLS.Conn
sTLSCurveID = sTLS.CurveID
sTLSRenegotiationSupport = sTLS.RenegotiationSupport
)
var (
sTLSCipherSuites = sTLS.CipherSuites
sTLSClient = sTLS.Client
)

44
tls_wrapper.go Normal file
View file

@ -0,0 +1,44 @@
package shadowtls
import (
"context"
"crypto/tls"
"net"
"github.com/sagernet/sing/common"
)
type (
TLSSessionIDGeneratorFunc func(clientHello []byte, sessionID []byte) error
TLSHandshakeFunc func(
ctx context.Context,
conn net.Conn,
sessionIDGenerator TLSSessionIDGeneratorFunc, // for shadow-tls version 3
) error
)
func DefaultTLSHandshakeFunc(password string, config *tls.Config) TLSHandshakeFunc {
return func(ctx context.Context, conn net.Conn, sessionIDGenerator TLSSessionIDGeneratorFunc) error {
tlsConfig := &sTLSConfig{
Rand: config.Rand,
Time: config.Time,
VerifyPeerCertificate: config.VerifyPeerCertificate,
RootCAs: config.RootCAs,
NextProtos: config.NextProtos,
ServerName: config.ServerName,
InsecureSkipVerify: config.InsecureSkipVerify,
CipherSuites: config.CipherSuites,
MinVersion: config.MinVersion,
MaxVersion: config.MaxVersion,
CurvePreferences: common.Map(config.CurvePreferences, func(it tls.CurveID) sTLSCurveID {
return sTLSCurveID(it)
}),
SessionTicketsDisabled: config.SessionTicketsDisabled,
Renegotiation: sTLSRenegotiationSupport(config.Renegotiation),
SessionIDGenerator: generateSessionID(password),
}
tlsConn := sTLSClient(conn, tlsConfig)
return tlsConn.HandshakeContext(ctx)
}
}

37
v1_server.go Normal file
View file

@ -0,0 +1,37 @@
package shadowtls
import (
"bytes"
"encoding/binary"
"io"
E "github.com/sagernet/sing/common/exceptions"
)
func copyUntilHandshakeFinished(dst io.Writer, src io.Reader) error {
var hasSeenChangeCipherSpec bool
var tlsHdr [tlsHeaderSize]byte
for {
_, err := io.ReadFull(src, tlsHdr[:])
if err != nil {
return err
}
length := binary.BigEndian.Uint16(tlsHdr[3:])
_, err = io.Copy(dst, io.MultiReader(bytes.NewReader(tlsHdr[:]), io.LimitReader(src, int64(length))))
if err != nil {
return err
}
if tlsHdr[0] != handshake {
if tlsHdr[0] != changeCipherSpec {
return E.New("unexpected tls frame type: ", tlsHdr[0])
}
if !hasSeenChangeCipherSpec {
hasSeenChangeCipherSpec = true
continue
}
}
if hasSeenChangeCipherSpec {
return nil
}
}
}

37
v2_client.go Normal file
View file

@ -0,0 +1,37 @@
package shadowtls
import (
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
)
type clientConn struct {
*shadowConn
hashConn *hashReadConn
}
func newClientConn(hashConn *hashReadConn) *clientConn {
return &clientConn{newConn(hashConn.Conn), hashConn}
}
func (c *clientConn) Write(p []byte) (n int, err error) {
if c.hashConn != nil {
sum := c.hashConn.Sum()
c.hashConn = nil
_, err = bufio.WriteVectorised(c.shadowConn, [][]byte{sum, p})
if err == nil {
n = len(p)
}
return
}
return c.shadowConn.Write(p)
}
func (c *clientConn) WriteVectorised(buffers []*buf.Buffer) error {
if c.hashConn != nil {
sum := c.hashConn.Sum()
c.hashConn = nil
return c.shadowConn.WriteVectorised(append([]*buf.Buffer{buf.As(sum)}, buffers...))
}
return c.shadowConn.WriteVectorised(buffers)
}

95
v2_conn.go Normal file
View file

@ -0,0 +1,95 @@
package shadowtls
import (
"crypto/tls"
"encoding/binary"
"io"
"net"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
)
type shadowConn struct {
net.Conn
writer N.VectorisedWriter
readRemaining int
}
func newConn(conn net.Conn) *shadowConn {
return &shadowConn{
Conn: conn,
writer: bufio.NewVectorisedWriter(conn),
}
}
func (c *shadowConn) Read(p []byte) (n int, err error) {
if c.readRemaining > 0 {
if len(p) > c.readRemaining {
p = p[:c.readRemaining]
}
n, err = c.Conn.Read(p)
c.readRemaining -= n
return
}
var tlsHeader [5]byte
_, err = io.ReadFull(c.Conn, common.Dup(tlsHeader[:]))
if err != nil {
return
}
length := int(binary.BigEndian.Uint16(tlsHeader[3:5]))
if tlsHeader[0] != 23 {
return 0, E.New("unexpected TLS record type: ", tlsHeader[0])
}
readLen := len(p)
if readLen > length {
readLen = length
}
n, err = c.Conn.Read(p[:readLen])
if err != nil {
return
}
c.readRemaining = length - n
return
}
func (c *shadowConn) Write(p []byte) (n int, err error) {
var header [tlsHeaderSize]byte
defer common.KeepAlive(header)
header[0] = 23
for len(p) > 16384 {
binary.BigEndian.PutUint16(header[1:3], tls.VersionTLS12)
binary.BigEndian.PutUint16(header[3:5], uint16(16384))
_, err = bufio.WriteVectorised(c.writer, [][]byte{common.Dup(header[:]), p[:16384]})
common.KeepAlive(header)
if err != nil {
return
}
n += 16384
p = p[16384:]
}
binary.BigEndian.PutUint16(header[1:3], tls.VersionTLS12)
binary.BigEndian.PutUint16(header[3:5], uint16(len(p)))
_, err = bufio.WriteVectorised(c.writer, [][]byte{common.Dup(header[:]), p})
if err == nil {
n += len(p)
}
return
}
func (c *shadowConn) WriteVectorised(buffers []*buf.Buffer) error {
var header [tlsHeaderSize]byte
defer common.KeepAlive(header)
header[0] = 23
dataLen := buf.LenMulti(buffers)
binary.BigEndian.PutUint16(header[1:3], tls.VersionTLS12)
binary.BigEndian.PutUint16(header[3:5], uint16(dataLen))
return c.writer.WriteVectorised(append([]*buf.Buffer{buf.As(header[:])}, buffers...))
}
func (c *shadowConn) Upstream() any {
return c.Conn
}

74
v2_hash.go Normal file
View file

@ -0,0 +1,74 @@
package shadowtls
import (
"crypto/hmac"
"crypto/sha1"
"hash"
"net"
)
type hashReadConn struct {
net.Conn
hmac hash.Hash
}
func newHashReadConn(conn net.Conn, password string) *hashReadConn {
return &hashReadConn{
conn,
hmac.New(sha1.New, []byte(password)),
}
}
func (c *hashReadConn) Read(b []byte) (n int, err error) {
n, err = c.Conn.Read(b)
if err != nil {
return
}
_, err = c.hmac.Write(b[:n])
return
}
func (c *hashReadConn) Sum() []byte {
return c.hmac.Sum(nil)[:8]
}
type hashWriteConn struct {
net.Conn
hmac hash.Hash
hasContent bool
lastSum []byte
}
func newHashWriteConn(conn net.Conn, password string) *hashWriteConn {
return &hashWriteConn{
Conn: conn,
hmac: hmac.New(sha1.New, []byte(password)),
}
}
func (c *hashWriteConn) Write(p []byte) (n int, err error) {
if c.hmac != nil {
if c.hasContent {
c.lastSum = c.Sum()
}
c.hmac.Write(p)
c.hasContent = true
}
return c.Conn.Write(p)
}
func (c *hashWriteConn) Sum() []byte {
return c.hmac.Sum(nil)[:8]
}
func (c *hashWriteConn) LastSum() []byte {
return c.lastSum
}
func (c *hashWriteConn) Fallback() {
c.hmac = nil
}
func (c *hashWriteConn) HasContent() bool {
return c.hasContent
}

58
v2_server.go Normal file
View file

@ -0,0 +1,58 @@
package shadowtls
import (
"bytes"
"context"
"encoding/binary"
"io"
"net"
"os"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/logger"
)
func copyUntilHandshakeFinishedV2(ctx context.Context, logger logger.ContextLogger, dst net.Conn, src io.Reader, hash *hashWriteConn, fallbackAfter int) (*buf.Buffer, error) {
var tlsHdr [tlsHeaderSize]byte
var applicationDataCount int
for {
_, err := io.ReadFull(src, tlsHdr[:])
if err != nil {
return nil, err
}
length := binary.BigEndian.Uint16(tlsHdr[3:])
if tlsHdr[0] == applicationData {
data := buf.NewSize(int(length))
_, err = data.ReadFullFrom(src, int(length))
if err != nil {
data.Release()
return nil, err
}
if hash.HasContent() && length >= 8 {
checksum := hash.Sum()
if bytes.Equal(data.To(8), checksum) {
logger.TraceContext(ctx, "match current hashcode")
data.Advance(8)
return data, nil
} else if hash.LastSum() != nil && bytes.Equal(data.To(8), hash.LastSum()) {
logger.TraceContext(ctx, "match last hashcode")
data.Advance(8)
return data, nil
} else {
logger.TraceContext(ctx, "hashcode mismatch")
}
}
_, err = io.Copy(dst, io.MultiReader(bytes.NewReader(tlsHdr[:]), data))
data.Release()
applicationDataCount++
} else {
_, err = io.Copy(dst, io.MultiReader(bytes.NewReader(tlsHdr[:]), io.LimitReader(src, int64(length))))
}
if err != nil {
return nil, err
}
if applicationDataCount > fallbackAfter {
return nil, os.ErrPermission
}
}
}

115
v3_client.go Normal file
View file

@ -0,0 +1,115 @@
package shadowtls
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha1"
"crypto/sha256"
"encoding/binary"
"hash"
"io"
"net"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
)
func generateSessionID(password string) func(clientHello []byte, sessionID []byte) error {
return func(clientHello []byte, sessionID []byte) error {
const sessionIDStart = 1 + 3 + 2 + tlsRandomSize + 1
if len(clientHello) < sessionIDStart+tlsSessionIDSize {
return E.New("unexpected client hello length")
}
_, err := rand.Read(sessionID[:tlsSessionIDSize-hmacSize])
if err != nil {
return err
}
hmacSHA1Hash := hmac.New(sha1.New, []byte(password))
hmacSHA1Hash.Write(clientHello[:sessionIDStart])
hmacSHA1Hash.Write(sessionID)
hmacSHA1Hash.Write(clientHello[sessionIDStart+tlsSessionIDSize:])
copy(sessionID[tlsSessionIDSize-hmacSize:], hmacSHA1Hash.Sum(nil)[:hmacSize])
return nil
}
}
type streamWrapper struct {
net.Conn
password string
buffer *buf.Buffer
serverRandom []byte
readHMAC hash.Hash
readHMACKey []byte
authorized bool
}
func newStreamWrapper(conn net.Conn, password string) *streamWrapper {
return &streamWrapper{
Conn: conn,
password: password,
}
}
func (w *streamWrapper) Authorized() (bool, []byte, hash.Hash) {
return w.authorized, w.serverRandom, w.readHMAC
}
func (w *streamWrapper) Read(p []byte) (n int, err error) {
if w.buffer != nil {
if !w.buffer.IsEmpty() {
return w.buffer.Read(p)
}
w.buffer.Release()
w.buffer = nil
}
var tlsHeader [tlsHeaderSize]byte
_, err = io.ReadFull(w.Conn, tlsHeader[:])
if err != nil {
return
}
length := int(binary.BigEndian.Uint16(tlsHeader[3:tlsHeaderSize]))
w.buffer = buf.NewSize(tlsHeaderSize + length)
common.Must1(w.buffer.Write(tlsHeader[:]))
_, err = w.buffer.ReadFullFrom(w.Conn, length)
if err != nil {
return
}
buffer := w.buffer.Bytes()
switch tlsHeader[0] {
case handshake:
if len(buffer) > serverRandomIndex+tlsRandomSize && buffer[5] == serverHello {
w.serverRandom = make([]byte, tlsRandomSize)
copy(w.serverRandom, buffer[serverRandomIndex:serverRandomIndex+tlsRandomSize])
w.readHMAC = hmac.New(sha1.New, []byte(w.password))
w.readHMAC.Write(w.serverRandom)
w.readHMACKey = kdf(w.password, w.serverRandom)
}
case applicationData:
w.authorized = false
if len(buffer) > tlsHmacHeaderSize && w.readHMAC != nil {
w.readHMAC.Write(buffer[tlsHmacHeaderSize:])
if hmac.Equal(w.readHMAC.Sum(nil)[:hmacSize], buffer[tlsHeaderSize:tlsHmacHeaderSize]) {
xorSlice(buffer[tlsHmacHeaderSize:], w.readHMACKey)
copy(buffer[hmacSize:], buffer[:tlsHeaderSize])
binary.BigEndian.PutUint16(buffer[hmacSize+3:], uint16(len(buffer)-tlsHmacHeaderSize))
w.buffer.Advance(hmacSize)
w.authorized = true
}
}
}
return w.buffer.Read(p)
}
func kdf(password string, serverRandom []byte) []byte {
hasher := sha256.New()
hasher.Write([]byte(password))
hasher.Write(serverRandom)
return hasher.Sum(nil)
}
func xorSlice(data []byte, key []byte) {
for i := range data {
data[i] ^= key[i%len(key)]
}
}

171
v3_conn.go Normal file
View file

@ -0,0 +1,171 @@
package shadowtls
import (
"bytes"
"crypto/rand"
"encoding/binary"
"hash"
"io"
"net"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
)
type verifiedConn struct {
net.Conn
writer N.VectorisedWriter
hmacAdd hash.Hash
hmacVerify hash.Hash
hmacIgnore hash.Hash
buffer *buf.Buffer
}
func newVerifiedConn(
conn net.Conn,
hmacAdd hash.Hash,
hmacVerify hash.Hash,
hmacIgnore hash.Hash,
) *verifiedConn {
return &verifiedConn{
Conn: conn,
writer: bufio.NewVectorisedWriter(conn),
hmacAdd: hmacAdd,
hmacVerify: hmacVerify,
hmacIgnore: hmacIgnore,
}
}
func (c *verifiedConn) Read(b []byte) (n int, err error) {
if c.buffer != nil {
if !c.buffer.IsEmpty() {
return c.buffer.Read(b)
}
c.buffer.Release()
c.buffer = nil
}
for {
var tlsHeader [tlsHeaderSize]byte
_, err = io.ReadFull(c.Conn, tlsHeader[:])
if err != nil {
sendAlert(c.Conn)
return
}
length := int(binary.BigEndian.Uint16(tlsHeader[3:tlsHeaderSize]))
c.buffer = buf.NewSize(tlsHeaderSize + length)
common.Must1(c.buffer.Write(tlsHeader[:]))
_, err = c.buffer.ReadFullFrom(c.Conn, length)
if err != nil {
return
}
buffer := c.buffer.Bytes()
switch buffer[0] {
case alert:
err = E.Cause(net.ErrClosed, "remote alert")
return
case applicationData:
if c.hmacIgnore != nil {
if verifyApplicationData(buffer, c.hmacIgnore, false) {
c.buffer.Release()
c.buffer = nil
continue
} else {
c.hmacIgnore = nil
}
}
if !verifyApplicationData(buffer, c.hmacVerify, true) {
sendAlert(c.Conn)
err = E.New("application data verification failed")
return
}
c.buffer.Advance(tlsHmacHeaderSize)
default:
sendAlert(c.Conn)
err = E.New("unexpected TLS record type: ", buffer[0])
return
}
return c.buffer.Read(b)
}
}
func (c *verifiedConn) Write(p []byte) (n int, err error) {
pTotal := len(p)
for len(p) > 0 {
var pWrite []byte
if len(p) > 16384 {
pWrite = p[:16384]
p = p[16384:]
} else {
pWrite = p
p = nil
}
_, err = c.write(pWrite)
}
if err == nil {
n = pTotal
}
return
}
func (c *verifiedConn) write(p []byte) (n int, err error) {
var header [tlsHmacHeaderSize]byte
header[0] = applicationData
header[1] = 3
header[2] = 3
binary.BigEndian.PutUint16(header[3:tlsHeaderSize], hmacSize+uint16(len(p)))
c.hmacAdd.Write(p)
hmacHash := c.hmacAdd.Sum(nil)[:hmacSize]
c.hmacAdd.Write(hmacHash)
copy(header[tlsHeaderSize:], hmacHash)
_, err = bufio.WriteVectorised(c.writer, [][]byte{common.Dup(header[:]), p})
if err == nil {
n = len(p)
}
return
}
func (c *verifiedConn) WriteVectorised(buffers []*buf.Buffer) error {
var header [tlsHmacHeaderSize]byte
header[0] = applicationData
header[1] = 3
header[2] = 3
binary.BigEndian.PutUint16(header[3:tlsHeaderSize], hmacSize+uint16(buf.LenMulti(buffers)))
for _, buffer := range buffers {
c.hmacAdd.Write(buffer.Bytes())
}
c.hmacAdd.Write(c.hmacAdd.Sum(nil)[:hmacSize])
copy(header[tlsHeaderSize:], c.hmacAdd.Sum(nil)[:hmacSize])
return c.writer.WriteVectorised(append([]*buf.Buffer{buf.As(header[:])}, buffers...))
}
func verifyApplicationData(frame []byte, hmac hash.Hash, update bool) bool {
if frame[1] != 3 || frame[2] != 3 || len(frame) < tlsHmacHeaderSize {
return false
}
hmac.Write(frame[tlsHmacHeaderSize:])
hmacHash := hmac.Sum(nil)[:hmacSize]
if update {
hmac.Write(hmacHash)
}
return bytes.Equal(frame[tlsHeaderSize:tlsHeaderSize+hmacSize], hmacHash)
}
func sendAlert(writer io.Writer) {
const recordSize = 31
record := [recordSize]byte{
alert,
3,
3,
0,
recordSize - tlsHeaderSize,
}
_, err := rand.Read(record[tlsHeaderSize:])
if err != nil {
return
}
writer.Write(record[:])
}

20
v3_constrat.go Normal file
View file

@ -0,0 +1,20 @@
package shadowtls
const (
tlsRandomSize = 32
tlsHeaderSize = 5
tlsSessionIDSize = 32
clientHello = 1
serverHello = 2
changeCipherSpec = 20
alert = 21
handshake = 22
applicationData = 23
serverRandomIndex = tlsHeaderSize + 1 + 3 + 2
sessionIDLengthIndex = tlsHeaderSize + 1 + 3 + 2 + tlsRandomSize
tlsHmacHeaderSize = tlsHeaderSize + hmacSize
hmacSize = 4
)

181
v3_server.go Normal file
View file

@ -0,0 +1,181 @@
package shadowtls
import (
"bytes"
"crypto/hmac"
"crypto/sha1"
"encoding/binary"
"hash"
"io"
"net"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/rw"
)
func extractFrame(conn net.Conn) (*buf.Buffer, error) {
var tlsHeader [tlsHeaderSize]byte
_, err := io.ReadFull(conn, tlsHeader[:])
if err != nil {
return nil, err
}
length := int(binary.BigEndian.Uint16(tlsHeader[3:]))
buffer := buf.NewSize(tlsHeaderSize + length)
common.Must1(buffer.Write(tlsHeader[:]))
_, err = buffer.ReadFullFrom(conn, length)
if err != nil {
buffer.Release()
}
return buffer, err
}
func verifyClientHello(frame []byte, password string) error {
const minLen = tlsHeaderSize + 1 + 3 + 2 + tlsRandomSize + 1 + tlsSessionIDSize
const hmacIndex = sessionIDLengthIndex + 1 + tlsSessionIDSize - hmacSize
if len(frame) < minLen {
return io.ErrUnexpectedEOF
} else if frame[0] != handshake {
return E.New("unexpected record type")
} else if frame[5] != clientHello {
return E.New("unexpected handshake type")
} else if frame[sessionIDLengthIndex] != tlsSessionIDSize {
return E.New("unexpected session id length")
}
hmacSHA1Hash := hmac.New(sha1.New, []byte(password))
hmacSHA1Hash.Write(frame[tlsHeaderSize:hmacIndex])
hmacSHA1Hash.Write(rw.ZeroBytes[:4])
hmacSHA1Hash.Write(frame[hmacIndex+hmacSize:])
if !hmac.Equal(frame[hmacIndex:hmacIndex+hmacSize], hmacSHA1Hash.Sum(nil)[:hmacSize]) {
return E.New("hmac mismatch")
}
return nil
}
func extractServerRandom(frame []byte) []byte {
const minLen = tlsHeaderSize + 1 + 3 + 2 + tlsRandomSize
if len(frame) < minLen || frame[0] != handshake || frame[5] != serverHello {
return nil
}
serverRandom := make([]byte, tlsRandomSize)
copy(serverRandom, frame[serverRandomIndex:serverRandomIndex+tlsRandomSize])
return serverRandom
}
func isServerHelloSupportTLS13(frame []byte) bool {
if len(frame) < sessionIDLengthIndex {
return false
}
reader := bytes.NewReader(frame[sessionIDLengthIndex:])
var sessionIdLength uint8
err := binary.Read(reader, binary.BigEndian, &sessionIdLength)
if err != nil {
return false
}
_, err = io.CopyN(io.Discard, reader, int64(sessionIdLength))
if err != nil {
return false
}
_, err = io.CopyN(io.Discard, reader, 3)
if err != nil {
return false
}
var extensionListLength uint16
err = binary.Read(reader, binary.BigEndian, &extensionListLength)
if err != nil {
return false
}
for i := uint16(0); i < extensionListLength; i++ {
var extensionType uint16
err = binary.Read(reader, binary.BigEndian, &extensionType)
if err != nil {
return false
}
var extensionLength uint16
err = binary.Read(reader, binary.BigEndian, &extensionLength)
if err != nil {
return false
}
if extensionType != 43 {
_, err = io.CopyN(io.Discard, reader, int64(extensionLength))
if err != nil {
return false
}
continue
}
if extensionLength != 2 {
return false
}
var extensionValue uint16
err = binary.Read(reader, binary.BigEndian, &extensionValue)
if err != nil {
return false
}
return extensionValue == 0x0304
}
return false
}
func copyByFrameUntilHMACMatches(conn net.Conn, handshakeConn net.Conn, hmacVerify hash.Hash, hmacReset func()) (*buf.Buffer, error) {
for {
frameBuffer, err := extractFrame(conn)
if err != nil {
return nil, E.Cause(err, "read client record")
}
frame := frameBuffer.Bytes()
if len(frame) > tlsHmacHeaderSize && frame[0] == applicationData {
hmacReset()
hmacVerify.Write(frame[tlsHmacHeaderSize:])
hmacHash := hmacVerify.Sum(nil)[:4]
if bytes.Equal(hmacHash, frame[tlsHeaderSize:tlsHmacHeaderSize]) {
hmacReset()
hmacVerify.Write(frame[tlsHmacHeaderSize:])
hmacVerify.Write(frame[tlsHeaderSize:tlsHmacHeaderSize])
frameBuffer.Advance(tlsHmacHeaderSize)
return frameBuffer, nil
}
}
_, err = handshakeConn.Write(frame)
frameBuffer.Release()
if err != nil {
return nil, E.Cause(err, "write clint frame")
}
}
}
func copyByFrameWithModification(conn net.Conn, handshakeConn net.Conn, password string, serverRandom []byte, hmacWrite hash.Hash) error {
writeKey := kdf(password, serverRandom)
writer := bufio.NewVectorisedWriter(handshakeConn)
for {
frameBuffer, err := extractFrame(conn)
if err != nil {
return E.Cause(err, "read server record")
}
frame := frameBuffer.Bytes()
if frame[0] == applicationData {
xorSlice(frame[tlsHeaderSize:], writeKey)
hmacWrite.Write(frame[tlsHeaderSize:])
binary.BigEndian.PutUint16(frame[3:], uint16(len(frame)-tlsHeaderSize+hmacSize))
hmacHash := hmacWrite.Sum(nil)[:4]
_, err = bufio.WriteVectorised(writer, [][]byte{frame[:tlsHeaderSize], hmacHash, frame[tlsHeaderSize:]})
frameBuffer.Release()
if err != nil {
return E.Cause(err, "write modified server frame")
}
} else {
_, err = handshakeConn.Write(frame)
frameBuffer.Release()
if err != nil {
return E.Cause(err, "write server frame")
}
}
}
}