From 6c9bdfc85869f9cc75043c1976a1f6e90350802d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 20 Feb 2023 12:52:48 +0800 Subject: [PATCH] Add client and service --- client.go | 104 ++++++++++++++++++++++++++ go.mod | 1 + go.sum | 2 + service.go | 196 +++++++++++++++++++++++++++++++++++++++++++++++++ tls.go | 20 +++++ tls_compact.go | 18 +++++ tls_wrapper.go | 44 +++++++++++ v1_server.go | 37 ++++++++++ v2_client.go | 37 ++++++++++ v2_conn.go | 95 ++++++++++++++++++++++++ v2_hash.go | 74 +++++++++++++++++++ v2_server.go | 58 +++++++++++++++ v3_client.go | 115 +++++++++++++++++++++++++++++ v3_conn.go | 171 ++++++++++++++++++++++++++++++++++++++++++ v3_constrat.go | 20 +++++ v3_server.go | 181 +++++++++++++++++++++++++++++++++++++++++++++ 16 files changed, 1173 insertions(+) create mode 100644 client.go create mode 100644 service.go create mode 100644 tls.go create mode 100644 tls_compact.go create mode 100644 tls_wrapper.go create mode 100644 v1_server.go create mode 100644 v2_client.go create mode 100644 v2_conn.go create mode 100644 v2_hash.go create mode 100644 v2_server.go create mode 100644 v3_client.go create mode 100644 v3_conn.go create mode 100644 v3_constrat.go create mode 100644 v3_server.go diff --git a/client.go b/client.go new file mode 100644 index 0000000..fe44c8a --- /dev/null +++ b/client.go @@ -0,0 +1,104 @@ +package shadowtls + +import ( + "context" + "crypto/hmac" + "crypto/sha1" + "encoding/hex" + "net" + "os" + + "github.com/sagernet/sing/common/debug" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +type ClientConfig struct { + Version int + Password string + Server M.Socksaddr + Dialer N.Dialer + TLSHandshake TLSHandshakeFunc + Logger logger.ContextLogger +} + +type Client struct { + version int + password string + server M.Socksaddr + dialer N.Dialer + tlsHandshake TLSHandshakeFunc + logger logger.ContextLogger +} + +func NewClient(config ClientConfig) (*Client, error) { + client := &Client{ + version: config.Version, + password: config.Password, + server: config.Server, + dialer: config.Dialer, + tlsHandshake: config.TLSHandshake, + logger: config.Logger, + } + if !client.server.IsValid() || client.dialer == nil || client.tlsHandshake == nil { + return nil, os.ErrInvalid + } + switch client.version { + case 1, 2, 3: + default: + return nil, E.New("unknown protocol version: ", client.version) + } + if client.dialer == nil { + client.dialer = N.SystemDialer + } + return client, nil +} + +func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { + conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.server) + if err != nil { + return nil, err + } + switch c.version { + default: + fallthrough + case 1: + err = c.tlsHandshake(ctx, conn, nil) + if err != nil { + return nil, err + } + c.logger.TraceContext(ctx, "clint handshake finished") + return conn, nil + case 2: + hashConn := newHashReadConn(conn, c.password) + err = c.tlsHandshake(ctx, hashConn, nil) + if err != nil { + return nil, err + } + c.logger.TraceContext(ctx, "clint handshake finished") + return newClientConn(hashConn), nil + case 3: + stream := newStreamWrapper(conn, c.password) + err = c.tlsHandshake(ctx, stream, generateSessionID(c.password)) + if err != nil { + return nil, err + } + c.logger.TraceContext(ctx, "handshake success") + authorized, serverRandom, readHMAC := stream.Authorized() + if !authorized { + return nil, E.New("traffic hijacked or TLS1.3 is not supported") + } + if debug.Enabled { + c.logger.TraceContext(ctx, "authorized, server random extracted: ", hex.EncodeToString(serverRandom)) + } + hmacAdd := hmac.New(sha1.New, []byte(c.password)) + hmacAdd.Write(serverRandom) + hmacAdd.Write([]byte("C")) + hmacVerify := hmac.New(sha1.New, []byte(c.password)) + hmacVerify.Write(serverRandom) + hmacVerify.Write([]byte("S")) + return newVerifiedConn(conn, hmacAdd, hmacVerify, readHMAC), nil + } +} diff --git a/go.mod b/go.mod index 88ed1a1..a277cf0 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/sagernet/sing-shadowtls go 1.18 require ( + github.com/sagernet/sing v0.1.6 golang.org/x/crypto v0.6.0 golang.org/x/sys v0.5.0 ) diff --git a/go.sum b/go.sum index 1ee4ca9..6e09fb0 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/sagernet/sing v0.1.6 h1:Qy63OUfKpcqKjfd5rPmUlj0RGjHZSK/PJn0duyCCsRg= +github.com/sagernet/sing v0.1.6/go.mod h1:JLSXsPTGRJFo/3X7EcAOCUgJH2/gAoxSJgBsnCZRp/w= golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc= golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= diff --git a/service.go b/service.go new file mode 100644 index 0000000..9431fe2 --- /dev/null +++ b/service.go @@ -0,0 +1,196 @@ +package shadowtls + +import ( + "context" + "crypto/hmac" + "crypto/sha1" + "encoding/hex" + "net" + "os" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + "github.com/sagernet/sing/common/debug" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/task" +) + +type ServiceConfig struct { + Version int + Password string + HandshakeServer M.Socksaddr + HandshakeDialer N.Dialer + Handler Handler + Logger logger.ContextLogger +} + +type Handler interface { + N.TCPConnectionHandler + E.Handler +} + +type Service struct { + version int + password string + handshakeServer M.Socksaddr + handshakeDialer N.Dialer + handler Handler + logger logger.ContextLogger +} + +func NewService(config ServiceConfig) (*Service, error) { + service := &Service{ + version: config.Version, + password: config.Password, + handshakeServer: config.HandshakeServer, + handshakeDialer: config.HandshakeDialer, + handler: config.Handler, + logger: config.Logger, + } + if !service.handshakeServer.IsValid() || service.handler == nil || service.logger == nil { + return nil, os.ErrInvalid + } + switch config.Version { + case 1, 2, 3: + default: + return nil, E.New("unknown protocol version: ", config.Version) + } + + return service, nil +} + +func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { + handshakeConn, err := s.handshakeDialer.DialContext(ctx, N.NetworkTCP, s.handshakeServer) + if err != nil { + return E.Cause(err, "server handshake") + } + switch s.version { + default: + fallthrough + case 1: + var group task.Group + group.Append("client handshake", func(ctx context.Context) error { + return copyUntilHandshakeFinished(handshakeConn, conn) + }) + group.Append("server handshake", func(ctx context.Context) error { + return copyUntilHandshakeFinished(conn, handshakeConn) + }) + group.FastFail() + group.Cleanup(func() { + handshakeConn.Close() + }) + err = group.Run(ctx) + if err != nil { + return err + } + s.logger.TraceContext(ctx, "handshake finished") + return s.handler.NewConnection(ctx, conn, metadata) + case 2: + hashConn := newHashWriteConn(conn, s.password) + go bufio.Copy(hashConn, handshakeConn) + var request *buf.Buffer + request, err = copyUntilHandshakeFinishedV2(ctx, s.logger, handshakeConn, conn, hashConn, 2) + if err == nil { + s.logger.TraceContext(ctx, "handshake finished") + handshakeConn.Close() + return s.handler.NewConnection(ctx, bufio.NewCachedConn(newConn(conn), request), metadata) + } else if err == os.ErrPermission { + s.logger.WarnContext(ctx, "fallback connection") + hashConn.Fallback() + return common.Error(bufio.Copy(handshakeConn, conn)) + } else { + return err + } + case 3: + var clientHelloFrame *buf.Buffer + clientHelloFrame, err = extractFrame(conn) + if err != nil { + return E.Cause(err, "read client handshake") + } + _, err = handshakeConn.Write(clientHelloFrame.Bytes()) + if err != nil { + clientHelloFrame.Release() + return E.Cause(err, "write client handshake") + } + err = verifyClientHello(clientHelloFrame.Bytes(), s.password) + if err != nil { + s.logger.WarnContext(ctx, E.Cause(err, "client hello verify failed")) + return bufio.CopyConn(ctx, conn, handshakeConn) + } + s.logger.TraceContext(ctx, "client hello verify success") + clientHelloFrame.Release() + + var serverHelloFrame *buf.Buffer + serverHelloFrame, err = extractFrame(handshakeConn) + if err != nil { + return E.Cause(err, "read server handshake") + } + + _, err = conn.Write(serverHelloFrame.Bytes()) + if err != nil { + serverHelloFrame.Release() + return E.Cause(err, "write server handshake") + } + + serverRandom := extractServerRandom(serverHelloFrame.Bytes()) + + if serverRandom == nil { + s.logger.WarnContext(ctx, "server random extract failed, will copy bidirectional") + return bufio.CopyConn(ctx, conn, handshakeConn) + } + + if !isServerHelloSupportTLS13(serverHelloFrame.Bytes()) { + s.logger.WarnContext(ctx, "TLS 1.3 is not supported, will copy bidirectional") + return bufio.CopyConn(ctx, conn, handshakeConn) + } + + serverHelloFrame.Release() + if debug.Enabled { + s.logger.TraceContext(ctx, "client authenticated. server random extracted: ", hex.EncodeToString(serverRandom)) + } + hmacWrite := hmac.New(sha1.New, []byte(s.password)) + hmacWrite.Write(serverRandom) + hmacAdd := hmac.New(sha1.New, []byte(s.password)) + hmacAdd.Write(serverRandom) + hmacAdd.Write([]byte("S")) + hmacVerify := hmac.New(sha1.New, []byte(s.password)) + hmacVerifyReset := func() { + hmacVerify.Reset() + hmacVerify.Write(serverRandom) + hmacVerify.Write([]byte("C")) + } + + var clientFirstFrame *buf.Buffer + var group task.Group + var handshakeFinished bool + group.Append("client handshake relay", func(ctx context.Context) error { + clientFrame, cErr := copyByFrameUntilHMACMatches(conn, handshakeConn, hmacVerify, hmacVerifyReset) + if cErr == nil { + clientFirstFrame = clientFrame + handshakeFinished = true + handshakeConn.Close() + } + return cErr + }) + group.Append("server handshake relay", func(ctx context.Context) error { + cErr := copyByFrameWithModification(handshakeConn, conn, s.password, serverRandom, hmacWrite) + if E.IsClosedOrCanceled(cErr) && handshakeFinished { + return nil + } + return cErr + }) + group.Cleanup(func() { + handshakeConn.Close() + }) + err = group.Run(ctx) + if err != nil { + return E.Cause(err, "handshake relay") + } + s.logger.TraceContext(ctx, "handshake relay finished") + return s.handler.NewConnection(ctx, bufio.NewCachedConn(newVerifiedConn(conn, hmacAdd, hmacVerify, nil), clientFirstFrame), metadata) + } +} diff --git a/tls.go b/tls.go new file mode 100644 index 0000000..851b4b0 --- /dev/null +++ b/tls.go @@ -0,0 +1,20 @@ +//go:build go1.20 + +package shadowtls + +import ( + sTLS "github.com/sagernet/sing-shadowtls/tls" +) + +type ( + sTLSConfig = sTLS.Config + sTLSConnectionState = sTLS.ConnectionState + sTLSConn = sTLS.Conn + sTLSCurveID = sTLS.CurveID + sTLSRenegotiationSupport = sTLS.RenegotiationSupport +) + +var ( + sTLSCipherSuites = sTLS.CipherSuites + sTLSClient = sTLS.Client +) diff --git a/tls_compact.go b/tls_compact.go new file mode 100644 index 0000000..e1b10b7 --- /dev/null +++ b/tls_compact.go @@ -0,0 +1,18 @@ +//go:build !go1.20 + +package shadowtls + +import sTLS "github.com/sagernet/sing-shadowtls/tls_compact" + +type ( + sTLSConfig = sTLS.Config + sTLSConnectionState = sTLS.ConnectionState + sTLSConn = sTLS.Conn + sTLSCurveID = sTLS.CurveID + sTLSRenegotiationSupport = sTLS.RenegotiationSupport +) + +var ( + sTLSCipherSuites = sTLS.CipherSuites + sTLSClient = sTLS.Client +) diff --git a/tls_wrapper.go b/tls_wrapper.go new file mode 100644 index 0000000..484347d --- /dev/null +++ b/tls_wrapper.go @@ -0,0 +1,44 @@ +package shadowtls + +import ( + "context" + "crypto/tls" + "net" + + "github.com/sagernet/sing/common" +) + +type ( + TLSSessionIDGeneratorFunc func(clientHello []byte, sessionID []byte) error + + TLSHandshakeFunc func( + ctx context.Context, + conn net.Conn, + sessionIDGenerator TLSSessionIDGeneratorFunc, // for shadow-tls version 3 + ) error +) + +func DefaultTLSHandshakeFunc(password string, config *tls.Config) TLSHandshakeFunc { + return func(ctx context.Context, conn net.Conn, sessionIDGenerator TLSSessionIDGeneratorFunc) error { + tlsConfig := &sTLSConfig{ + Rand: config.Rand, + Time: config.Time, + VerifyPeerCertificate: config.VerifyPeerCertificate, + RootCAs: config.RootCAs, + NextProtos: config.NextProtos, + ServerName: config.ServerName, + InsecureSkipVerify: config.InsecureSkipVerify, + CipherSuites: config.CipherSuites, + MinVersion: config.MinVersion, + MaxVersion: config.MaxVersion, + CurvePreferences: common.Map(config.CurvePreferences, func(it tls.CurveID) sTLSCurveID { + return sTLSCurveID(it) + }), + SessionTicketsDisabled: config.SessionTicketsDisabled, + Renegotiation: sTLSRenegotiationSupport(config.Renegotiation), + SessionIDGenerator: generateSessionID(password), + } + tlsConn := sTLSClient(conn, tlsConfig) + return tlsConn.HandshakeContext(ctx) + } +} diff --git a/v1_server.go b/v1_server.go new file mode 100644 index 0000000..9485316 --- /dev/null +++ b/v1_server.go @@ -0,0 +1,37 @@ +package shadowtls + +import ( + "bytes" + "encoding/binary" + "io" + + E "github.com/sagernet/sing/common/exceptions" +) + +func copyUntilHandshakeFinished(dst io.Writer, src io.Reader) error { + var hasSeenChangeCipherSpec bool + var tlsHdr [tlsHeaderSize]byte + for { + _, err := io.ReadFull(src, tlsHdr[:]) + if err != nil { + return err + } + length := binary.BigEndian.Uint16(tlsHdr[3:]) + _, err = io.Copy(dst, io.MultiReader(bytes.NewReader(tlsHdr[:]), io.LimitReader(src, int64(length)))) + if err != nil { + return err + } + if tlsHdr[0] != handshake { + if tlsHdr[0] != changeCipherSpec { + return E.New("unexpected tls frame type: ", tlsHdr[0]) + } + if !hasSeenChangeCipherSpec { + hasSeenChangeCipherSpec = true + continue + } + } + if hasSeenChangeCipherSpec { + return nil + } + } +} diff --git a/v2_client.go b/v2_client.go new file mode 100644 index 0000000..a600cc4 --- /dev/null +++ b/v2_client.go @@ -0,0 +1,37 @@ +package shadowtls + +import ( + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" +) + +type clientConn struct { + *shadowConn + hashConn *hashReadConn +} + +func newClientConn(hashConn *hashReadConn) *clientConn { + return &clientConn{newConn(hashConn.Conn), hashConn} +} + +func (c *clientConn) Write(p []byte) (n int, err error) { + if c.hashConn != nil { + sum := c.hashConn.Sum() + c.hashConn = nil + _, err = bufio.WriteVectorised(c.shadowConn, [][]byte{sum, p}) + if err == nil { + n = len(p) + } + return + } + return c.shadowConn.Write(p) +} + +func (c *clientConn) WriteVectorised(buffers []*buf.Buffer) error { + if c.hashConn != nil { + sum := c.hashConn.Sum() + c.hashConn = nil + return c.shadowConn.WriteVectorised(append([]*buf.Buffer{buf.As(sum)}, buffers...)) + } + return c.shadowConn.WriteVectorised(buffers) +} diff --git a/v2_conn.go b/v2_conn.go new file mode 100644 index 0000000..02ff3d0 --- /dev/null +++ b/v2_conn.go @@ -0,0 +1,95 @@ +package shadowtls + +import ( + "crypto/tls" + "encoding/binary" + "io" + "net" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" + N "github.com/sagernet/sing/common/network" +) + +type shadowConn struct { + net.Conn + writer N.VectorisedWriter + readRemaining int +} + +func newConn(conn net.Conn) *shadowConn { + return &shadowConn{ + Conn: conn, + writer: bufio.NewVectorisedWriter(conn), + } +} + +func (c *shadowConn) Read(p []byte) (n int, err error) { + if c.readRemaining > 0 { + if len(p) > c.readRemaining { + p = p[:c.readRemaining] + } + n, err = c.Conn.Read(p) + c.readRemaining -= n + return + } + var tlsHeader [5]byte + _, err = io.ReadFull(c.Conn, common.Dup(tlsHeader[:])) + if err != nil { + return + } + length := int(binary.BigEndian.Uint16(tlsHeader[3:5])) + if tlsHeader[0] != 23 { + return 0, E.New("unexpected TLS record type: ", tlsHeader[0]) + } + readLen := len(p) + if readLen > length { + readLen = length + } + n, err = c.Conn.Read(p[:readLen]) + if err != nil { + return + } + c.readRemaining = length - n + return +} + +func (c *shadowConn) Write(p []byte) (n int, err error) { + var header [tlsHeaderSize]byte + defer common.KeepAlive(header) + header[0] = 23 + for len(p) > 16384 { + binary.BigEndian.PutUint16(header[1:3], tls.VersionTLS12) + binary.BigEndian.PutUint16(header[3:5], uint16(16384)) + _, err = bufio.WriteVectorised(c.writer, [][]byte{common.Dup(header[:]), p[:16384]}) + common.KeepAlive(header) + if err != nil { + return + } + n += 16384 + p = p[16384:] + } + binary.BigEndian.PutUint16(header[1:3], tls.VersionTLS12) + binary.BigEndian.PutUint16(header[3:5], uint16(len(p))) + _, err = bufio.WriteVectorised(c.writer, [][]byte{common.Dup(header[:]), p}) + if err == nil { + n += len(p) + } + return +} + +func (c *shadowConn) WriteVectorised(buffers []*buf.Buffer) error { + var header [tlsHeaderSize]byte + defer common.KeepAlive(header) + header[0] = 23 + dataLen := buf.LenMulti(buffers) + binary.BigEndian.PutUint16(header[1:3], tls.VersionTLS12) + binary.BigEndian.PutUint16(header[3:5], uint16(dataLen)) + return c.writer.WriteVectorised(append([]*buf.Buffer{buf.As(header[:])}, buffers...)) +} + +func (c *shadowConn) Upstream() any { + return c.Conn +} diff --git a/v2_hash.go b/v2_hash.go new file mode 100644 index 0000000..d8b2d7a --- /dev/null +++ b/v2_hash.go @@ -0,0 +1,74 @@ +package shadowtls + +import ( + "crypto/hmac" + "crypto/sha1" + "hash" + "net" +) + +type hashReadConn struct { + net.Conn + hmac hash.Hash +} + +func newHashReadConn(conn net.Conn, password string) *hashReadConn { + return &hashReadConn{ + conn, + hmac.New(sha1.New, []byte(password)), + } +} + +func (c *hashReadConn) Read(b []byte) (n int, err error) { + n, err = c.Conn.Read(b) + if err != nil { + return + } + _, err = c.hmac.Write(b[:n]) + return +} + +func (c *hashReadConn) Sum() []byte { + return c.hmac.Sum(nil)[:8] +} + +type hashWriteConn struct { + net.Conn + hmac hash.Hash + hasContent bool + lastSum []byte +} + +func newHashWriteConn(conn net.Conn, password string) *hashWriteConn { + return &hashWriteConn{ + Conn: conn, + hmac: hmac.New(sha1.New, []byte(password)), + } +} + +func (c *hashWriteConn) Write(p []byte) (n int, err error) { + if c.hmac != nil { + if c.hasContent { + c.lastSum = c.Sum() + } + c.hmac.Write(p) + c.hasContent = true + } + return c.Conn.Write(p) +} + +func (c *hashWriteConn) Sum() []byte { + return c.hmac.Sum(nil)[:8] +} + +func (c *hashWriteConn) LastSum() []byte { + return c.lastSum +} + +func (c *hashWriteConn) Fallback() { + c.hmac = nil +} + +func (c *hashWriteConn) HasContent() bool { + return c.hasContent +} diff --git a/v2_server.go b/v2_server.go new file mode 100644 index 0000000..6f941fa --- /dev/null +++ b/v2_server.go @@ -0,0 +1,58 @@ +package shadowtls + +import ( + "bytes" + "context" + "encoding/binary" + "io" + "net" + "os" + + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/logger" +) + +func copyUntilHandshakeFinishedV2(ctx context.Context, logger logger.ContextLogger, dst net.Conn, src io.Reader, hash *hashWriteConn, fallbackAfter int) (*buf.Buffer, error) { + var tlsHdr [tlsHeaderSize]byte + var applicationDataCount int + for { + _, err := io.ReadFull(src, tlsHdr[:]) + if err != nil { + return nil, err + } + length := binary.BigEndian.Uint16(tlsHdr[3:]) + if tlsHdr[0] == applicationData { + data := buf.NewSize(int(length)) + _, err = data.ReadFullFrom(src, int(length)) + if err != nil { + data.Release() + return nil, err + } + if hash.HasContent() && length >= 8 { + checksum := hash.Sum() + if bytes.Equal(data.To(8), checksum) { + logger.TraceContext(ctx, "match current hashcode") + data.Advance(8) + return data, nil + } else if hash.LastSum() != nil && bytes.Equal(data.To(8), hash.LastSum()) { + logger.TraceContext(ctx, "match last hashcode") + data.Advance(8) + return data, nil + } else { + logger.TraceContext(ctx, "hashcode mismatch") + } + } + _, err = io.Copy(dst, io.MultiReader(bytes.NewReader(tlsHdr[:]), data)) + data.Release() + applicationDataCount++ + } else { + _, err = io.Copy(dst, io.MultiReader(bytes.NewReader(tlsHdr[:]), io.LimitReader(src, int64(length)))) + } + if err != nil { + return nil, err + } + if applicationDataCount > fallbackAfter { + return nil, os.ErrPermission + } + } +} diff --git a/v3_client.go b/v3_client.go new file mode 100644 index 0000000..0976f24 --- /dev/null +++ b/v3_client.go @@ -0,0 +1,115 @@ +package shadowtls + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha1" + "crypto/sha256" + "encoding/binary" + "hash" + "io" + "net" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" +) + +func generateSessionID(password string) func(clientHello []byte, sessionID []byte) error { + return func(clientHello []byte, sessionID []byte) error { + const sessionIDStart = 1 + 3 + 2 + tlsRandomSize + 1 + if len(clientHello) < sessionIDStart+tlsSessionIDSize { + return E.New("unexpected client hello length") + } + _, err := rand.Read(sessionID[:tlsSessionIDSize-hmacSize]) + if err != nil { + return err + } + hmacSHA1Hash := hmac.New(sha1.New, []byte(password)) + hmacSHA1Hash.Write(clientHello[:sessionIDStart]) + hmacSHA1Hash.Write(sessionID) + hmacSHA1Hash.Write(clientHello[sessionIDStart+tlsSessionIDSize:]) + copy(sessionID[tlsSessionIDSize-hmacSize:], hmacSHA1Hash.Sum(nil)[:hmacSize]) + return nil + } +} + +type streamWrapper struct { + net.Conn + password string + buffer *buf.Buffer + serverRandom []byte + readHMAC hash.Hash + readHMACKey []byte + authorized bool +} + +func newStreamWrapper(conn net.Conn, password string) *streamWrapper { + return &streamWrapper{ + Conn: conn, + password: password, + } +} + +func (w *streamWrapper) Authorized() (bool, []byte, hash.Hash) { + return w.authorized, w.serverRandom, w.readHMAC +} + +func (w *streamWrapper) Read(p []byte) (n int, err error) { + if w.buffer != nil { + if !w.buffer.IsEmpty() { + return w.buffer.Read(p) + } + w.buffer.Release() + w.buffer = nil + } + var tlsHeader [tlsHeaderSize]byte + _, err = io.ReadFull(w.Conn, tlsHeader[:]) + if err != nil { + return + } + length := int(binary.BigEndian.Uint16(tlsHeader[3:tlsHeaderSize])) + w.buffer = buf.NewSize(tlsHeaderSize + length) + common.Must1(w.buffer.Write(tlsHeader[:])) + _, err = w.buffer.ReadFullFrom(w.Conn, length) + if err != nil { + return + } + buffer := w.buffer.Bytes() + switch tlsHeader[0] { + case handshake: + if len(buffer) > serverRandomIndex+tlsRandomSize && buffer[5] == serverHello { + w.serverRandom = make([]byte, tlsRandomSize) + copy(w.serverRandom, buffer[serverRandomIndex:serverRandomIndex+tlsRandomSize]) + w.readHMAC = hmac.New(sha1.New, []byte(w.password)) + w.readHMAC.Write(w.serverRandom) + w.readHMACKey = kdf(w.password, w.serverRandom) + } + case applicationData: + w.authorized = false + if len(buffer) > tlsHmacHeaderSize && w.readHMAC != nil { + w.readHMAC.Write(buffer[tlsHmacHeaderSize:]) + if hmac.Equal(w.readHMAC.Sum(nil)[:hmacSize], buffer[tlsHeaderSize:tlsHmacHeaderSize]) { + xorSlice(buffer[tlsHmacHeaderSize:], w.readHMACKey) + copy(buffer[hmacSize:], buffer[:tlsHeaderSize]) + binary.BigEndian.PutUint16(buffer[hmacSize+3:], uint16(len(buffer)-tlsHmacHeaderSize)) + w.buffer.Advance(hmacSize) + w.authorized = true + } + } + } + return w.buffer.Read(p) +} + +func kdf(password string, serverRandom []byte) []byte { + hasher := sha256.New() + hasher.Write([]byte(password)) + hasher.Write(serverRandom) + return hasher.Sum(nil) +} + +func xorSlice(data []byte, key []byte) { + for i := range data { + data[i] ^= key[i%len(key)] + } +} diff --git a/v3_conn.go b/v3_conn.go new file mode 100644 index 0000000..2d2d408 --- /dev/null +++ b/v3_conn.go @@ -0,0 +1,171 @@ +package shadowtls + +import ( + "bytes" + "crypto/rand" + "encoding/binary" + "hash" + "io" + "net" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" + N "github.com/sagernet/sing/common/network" +) + +type verifiedConn struct { + net.Conn + writer N.VectorisedWriter + hmacAdd hash.Hash + hmacVerify hash.Hash + hmacIgnore hash.Hash + + buffer *buf.Buffer +} + +func newVerifiedConn( + conn net.Conn, + hmacAdd hash.Hash, + hmacVerify hash.Hash, + hmacIgnore hash.Hash, +) *verifiedConn { + return &verifiedConn{ + Conn: conn, + writer: bufio.NewVectorisedWriter(conn), + hmacAdd: hmacAdd, + hmacVerify: hmacVerify, + hmacIgnore: hmacIgnore, + } +} + +func (c *verifiedConn) Read(b []byte) (n int, err error) { + if c.buffer != nil { + if !c.buffer.IsEmpty() { + return c.buffer.Read(b) + } + c.buffer.Release() + c.buffer = nil + } + for { + var tlsHeader [tlsHeaderSize]byte + _, err = io.ReadFull(c.Conn, tlsHeader[:]) + if err != nil { + sendAlert(c.Conn) + return + } + length := int(binary.BigEndian.Uint16(tlsHeader[3:tlsHeaderSize])) + c.buffer = buf.NewSize(tlsHeaderSize + length) + common.Must1(c.buffer.Write(tlsHeader[:])) + _, err = c.buffer.ReadFullFrom(c.Conn, length) + if err != nil { + return + } + buffer := c.buffer.Bytes() + switch buffer[0] { + case alert: + err = E.Cause(net.ErrClosed, "remote alert") + return + case applicationData: + if c.hmacIgnore != nil { + if verifyApplicationData(buffer, c.hmacIgnore, false) { + c.buffer.Release() + c.buffer = nil + continue + } else { + c.hmacIgnore = nil + } + } + if !verifyApplicationData(buffer, c.hmacVerify, true) { + sendAlert(c.Conn) + err = E.New("application data verification failed") + return + } + c.buffer.Advance(tlsHmacHeaderSize) + default: + sendAlert(c.Conn) + err = E.New("unexpected TLS record type: ", buffer[0]) + return + } + return c.buffer.Read(b) + } +} + +func (c *verifiedConn) Write(p []byte) (n int, err error) { + pTotal := len(p) + for len(p) > 0 { + var pWrite []byte + if len(p) > 16384 { + pWrite = p[:16384] + p = p[16384:] + } else { + pWrite = p + p = nil + } + _, err = c.write(pWrite) + } + if err == nil { + n = pTotal + } + return +} + +func (c *verifiedConn) write(p []byte) (n int, err error) { + var header [tlsHmacHeaderSize]byte + header[0] = applicationData + header[1] = 3 + header[2] = 3 + binary.BigEndian.PutUint16(header[3:tlsHeaderSize], hmacSize+uint16(len(p))) + c.hmacAdd.Write(p) + hmacHash := c.hmacAdd.Sum(nil)[:hmacSize] + c.hmacAdd.Write(hmacHash) + copy(header[tlsHeaderSize:], hmacHash) + _, err = bufio.WriteVectorised(c.writer, [][]byte{common.Dup(header[:]), p}) + if err == nil { + n = len(p) + } + return +} + +func (c *verifiedConn) WriteVectorised(buffers []*buf.Buffer) error { + var header [tlsHmacHeaderSize]byte + header[0] = applicationData + header[1] = 3 + header[2] = 3 + binary.BigEndian.PutUint16(header[3:tlsHeaderSize], hmacSize+uint16(buf.LenMulti(buffers))) + for _, buffer := range buffers { + c.hmacAdd.Write(buffer.Bytes()) + } + c.hmacAdd.Write(c.hmacAdd.Sum(nil)[:hmacSize]) + copy(header[tlsHeaderSize:], c.hmacAdd.Sum(nil)[:hmacSize]) + return c.writer.WriteVectorised(append([]*buf.Buffer{buf.As(header[:])}, buffers...)) +} + +func verifyApplicationData(frame []byte, hmac hash.Hash, update bool) bool { + if frame[1] != 3 || frame[2] != 3 || len(frame) < tlsHmacHeaderSize { + return false + } + hmac.Write(frame[tlsHmacHeaderSize:]) + hmacHash := hmac.Sum(nil)[:hmacSize] + if update { + hmac.Write(hmacHash) + } + return bytes.Equal(frame[tlsHeaderSize:tlsHeaderSize+hmacSize], hmacHash) +} + +func sendAlert(writer io.Writer) { + const recordSize = 31 + record := [recordSize]byte{ + alert, + 3, + 3, + 0, + recordSize - tlsHeaderSize, + } + _, err := rand.Read(record[tlsHeaderSize:]) + if err != nil { + return + } + writer.Write(record[:]) +} diff --git a/v3_constrat.go b/v3_constrat.go new file mode 100644 index 0000000..a7fdbff --- /dev/null +++ b/v3_constrat.go @@ -0,0 +1,20 @@ +package shadowtls + +const ( + tlsRandomSize = 32 + tlsHeaderSize = 5 + tlsSessionIDSize = 32 + + clientHello = 1 + serverHello = 2 + + changeCipherSpec = 20 + alert = 21 + handshake = 22 + applicationData = 23 + + serverRandomIndex = tlsHeaderSize + 1 + 3 + 2 + sessionIDLengthIndex = tlsHeaderSize + 1 + 3 + 2 + tlsRandomSize + tlsHmacHeaderSize = tlsHeaderSize + hmacSize + hmacSize = 4 +) diff --git a/v3_server.go b/v3_server.go new file mode 100644 index 0000000..8dbf95b --- /dev/null +++ b/v3_server.go @@ -0,0 +1,181 @@ +package shadowtls + +import ( + "bytes" + "crypto/hmac" + "crypto/sha1" + "encoding/binary" + "hash" + "io" + "net" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/rw" +) + +func extractFrame(conn net.Conn) (*buf.Buffer, error) { + var tlsHeader [tlsHeaderSize]byte + _, err := io.ReadFull(conn, tlsHeader[:]) + if err != nil { + return nil, err + } + length := int(binary.BigEndian.Uint16(tlsHeader[3:])) + buffer := buf.NewSize(tlsHeaderSize + length) + common.Must1(buffer.Write(tlsHeader[:])) + _, err = buffer.ReadFullFrom(conn, length) + if err != nil { + buffer.Release() + } + return buffer, err +} + +func verifyClientHello(frame []byte, password string) error { + const minLen = tlsHeaderSize + 1 + 3 + 2 + tlsRandomSize + 1 + tlsSessionIDSize + const hmacIndex = sessionIDLengthIndex + 1 + tlsSessionIDSize - hmacSize + if len(frame) < minLen { + return io.ErrUnexpectedEOF + } else if frame[0] != handshake { + return E.New("unexpected record type") + } else if frame[5] != clientHello { + return E.New("unexpected handshake type") + } else if frame[sessionIDLengthIndex] != tlsSessionIDSize { + return E.New("unexpected session id length") + } + hmacSHA1Hash := hmac.New(sha1.New, []byte(password)) + hmacSHA1Hash.Write(frame[tlsHeaderSize:hmacIndex]) + hmacSHA1Hash.Write(rw.ZeroBytes[:4]) + hmacSHA1Hash.Write(frame[hmacIndex+hmacSize:]) + if !hmac.Equal(frame[hmacIndex:hmacIndex+hmacSize], hmacSHA1Hash.Sum(nil)[:hmacSize]) { + return E.New("hmac mismatch") + } + return nil +} + +func extractServerRandom(frame []byte) []byte { + const minLen = tlsHeaderSize + 1 + 3 + 2 + tlsRandomSize + + if len(frame) < minLen || frame[0] != handshake || frame[5] != serverHello { + return nil + } + + serverRandom := make([]byte, tlsRandomSize) + copy(serverRandom, frame[serverRandomIndex:serverRandomIndex+tlsRandomSize]) + return serverRandom +} + +func isServerHelloSupportTLS13(frame []byte) bool { + if len(frame) < sessionIDLengthIndex { + return false + } + + reader := bytes.NewReader(frame[sessionIDLengthIndex:]) + + var sessionIdLength uint8 + err := binary.Read(reader, binary.BigEndian, &sessionIdLength) + if err != nil { + return false + } + _, err = io.CopyN(io.Discard, reader, int64(sessionIdLength)) + if err != nil { + return false + } + + _, err = io.CopyN(io.Discard, reader, 3) + if err != nil { + return false + } + + var extensionListLength uint16 + err = binary.Read(reader, binary.BigEndian, &extensionListLength) + if err != nil { + return false + } + for i := uint16(0); i < extensionListLength; i++ { + var extensionType uint16 + err = binary.Read(reader, binary.BigEndian, &extensionType) + if err != nil { + return false + } + var extensionLength uint16 + err = binary.Read(reader, binary.BigEndian, &extensionLength) + if err != nil { + return false + } + if extensionType != 43 { + _, err = io.CopyN(io.Discard, reader, int64(extensionLength)) + if err != nil { + return false + } + continue + } + if extensionLength != 2 { + return false + } + var extensionValue uint16 + err = binary.Read(reader, binary.BigEndian, &extensionValue) + if err != nil { + return false + } + return extensionValue == 0x0304 + } + return false +} + +func copyByFrameUntilHMACMatches(conn net.Conn, handshakeConn net.Conn, hmacVerify hash.Hash, hmacReset func()) (*buf.Buffer, error) { + for { + frameBuffer, err := extractFrame(conn) + if err != nil { + return nil, E.Cause(err, "read client record") + } + frame := frameBuffer.Bytes() + if len(frame) > tlsHmacHeaderSize && frame[0] == applicationData { + hmacReset() + hmacVerify.Write(frame[tlsHmacHeaderSize:]) + hmacHash := hmacVerify.Sum(nil)[:4] + if bytes.Equal(hmacHash, frame[tlsHeaderSize:tlsHmacHeaderSize]) { + hmacReset() + hmacVerify.Write(frame[tlsHmacHeaderSize:]) + hmacVerify.Write(frame[tlsHeaderSize:tlsHmacHeaderSize]) + frameBuffer.Advance(tlsHmacHeaderSize) + return frameBuffer, nil + } + } + _, err = handshakeConn.Write(frame) + frameBuffer.Release() + if err != nil { + return nil, E.Cause(err, "write clint frame") + } + } +} + +func copyByFrameWithModification(conn net.Conn, handshakeConn net.Conn, password string, serverRandom []byte, hmacWrite hash.Hash) error { + writeKey := kdf(password, serverRandom) + writer := bufio.NewVectorisedWriter(handshakeConn) + for { + frameBuffer, err := extractFrame(conn) + if err != nil { + return E.Cause(err, "read server record") + } + frame := frameBuffer.Bytes() + if frame[0] == applicationData { + xorSlice(frame[tlsHeaderSize:], writeKey) + hmacWrite.Write(frame[tlsHeaderSize:]) + binary.BigEndian.PutUint16(frame[3:], uint16(len(frame)-tlsHeaderSize+hmacSize)) + hmacHash := hmacWrite.Sum(nil)[:4] + _, err = bufio.WriteVectorised(writer, [][]byte{frame[:tlsHeaderSize], hmacHash, frame[tlsHeaderSize:]}) + frameBuffer.Release() + if err != nil { + return E.Cause(err, "write modified server frame") + } + } else { + _, err = handshakeConn.Write(frame) + frameBuffer.Release() + if err != nil { + return E.Cause(err, "write server frame") + } + } + } +}