sing-shadowtls/service.go
2023-02-20 13:51:43 +08:00

196 lines
5.7 KiB
Go

package shadowtls
import (
"context"
"crypto/hmac"
"crypto/sha1"
"encoding/hex"
"net"
"os"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/debug"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/task"
)
type ServiceConfig struct {
Version int
Password string
HandshakeServer M.Socksaddr
HandshakeDialer N.Dialer
Handler Handler
Logger logger.ContextLogger
}
type Handler interface {
N.TCPConnectionHandler
E.Handler
}
type Service struct {
version int
password string
handshakeServer M.Socksaddr
handshakeDialer N.Dialer
handler Handler
logger logger.ContextLogger
}
func NewService(config ServiceConfig) (*Service, error) {
service := &Service{
version: config.Version,
password: config.Password,
handshakeServer: config.HandshakeServer,
handshakeDialer: config.HandshakeDialer,
handler: config.Handler,
logger: config.Logger,
}
if !service.handshakeServer.IsValid() || service.handler == nil || service.logger == nil {
return nil, os.ErrInvalid
}
switch config.Version {
case 1, 2, 3:
default:
return nil, E.New("unknown protocol version: ", config.Version)
}
return service, nil
}
func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
handshakeConn, err := s.handshakeDialer.DialContext(ctx, N.NetworkTCP, s.handshakeServer)
if err != nil {
return E.Cause(err, "server handshake")
}
switch s.version {
default:
fallthrough
case 1:
var group task.Group
group.Append("client handshake", func(ctx context.Context) error {
return copyUntilHandshakeFinished(handshakeConn, conn)
})
group.Append("server handshake", func(ctx context.Context) error {
return copyUntilHandshakeFinished(conn, handshakeConn)
})
group.FastFail()
group.Cleanup(func() {
handshakeConn.Close()
})
err = group.Run(ctx)
if err != nil {
return err
}
s.logger.TraceContext(ctx, "handshake finished")
return s.handler.NewConnection(ctx, conn, metadata)
case 2:
hashConn := newHashWriteConn(conn, s.password)
go bufio.Copy(hashConn, handshakeConn)
var request *buf.Buffer
request, err = copyUntilHandshakeFinishedV2(ctx, s.logger, handshakeConn, conn, hashConn, 2)
if err == nil {
s.logger.TraceContext(ctx, "handshake finished")
handshakeConn.Close()
return s.handler.NewConnection(ctx, bufio.NewCachedConn(newConn(conn), request), metadata)
} else if err == os.ErrPermission {
s.logger.WarnContext(ctx, "fallback connection")
hashConn.Fallback()
return common.Error(bufio.Copy(handshakeConn, conn))
} else {
return err
}
case 3:
var clientHelloFrame *buf.Buffer
clientHelloFrame, err = extractFrame(conn)
if err != nil {
return E.Cause(err, "read client handshake")
}
_, err = handshakeConn.Write(clientHelloFrame.Bytes())
if err != nil {
clientHelloFrame.Release()
return E.Cause(err, "write client handshake")
}
err = verifyClientHello(clientHelloFrame.Bytes(), s.password)
if err != nil {
s.logger.WarnContext(ctx, E.Cause(err, "client hello verify failed"))
return bufio.CopyConn(ctx, conn, handshakeConn)
}
s.logger.TraceContext(ctx, "client hello verify success")
clientHelloFrame.Release()
var serverHelloFrame *buf.Buffer
serverHelloFrame, err = extractFrame(handshakeConn)
if err != nil {
return E.Cause(err, "read server handshake")
}
_, err = conn.Write(serverHelloFrame.Bytes())
if err != nil {
serverHelloFrame.Release()
return E.Cause(err, "write server handshake")
}
serverRandom := extractServerRandom(serverHelloFrame.Bytes())
if serverRandom == nil {
s.logger.WarnContext(ctx, "server random extract failed, will copy bidirectional")
return bufio.CopyConn(ctx, conn, handshakeConn)
}
if !isServerHelloSupportTLS13(serverHelloFrame.Bytes()) {
s.logger.WarnContext(ctx, "TLS 1.3 is not supported, will copy bidirectional")
return bufio.CopyConn(ctx, conn, handshakeConn)
}
serverHelloFrame.Release()
if debug.Enabled {
s.logger.TraceContext(ctx, "client authenticated. server random extracted: ", hex.EncodeToString(serverRandom))
}
hmacWrite := hmac.New(sha1.New, []byte(s.password))
hmacWrite.Write(serverRandom)
hmacAdd := hmac.New(sha1.New, []byte(s.password))
hmacAdd.Write(serverRandom)
hmacAdd.Write([]byte("S"))
hmacVerify := hmac.New(sha1.New, []byte(s.password))
hmacVerifyReset := func() {
hmacVerify.Reset()
hmacVerify.Write(serverRandom)
hmacVerify.Write([]byte("C"))
}
var clientFirstFrame *buf.Buffer
var group task.Group
var handshakeFinished bool
group.Append("client handshake relay", func(ctx context.Context) error {
clientFrame, cErr := copyByFrameUntilHMACMatches(conn, handshakeConn, hmacVerify, hmacVerifyReset)
if cErr == nil {
clientFirstFrame = clientFrame
handshakeFinished = true
handshakeConn.Close()
}
return cErr
})
group.Append("server handshake relay", func(ctx context.Context) error {
cErr := copyByFrameWithModification(handshakeConn, conn, s.password, serverRandom, hmacWrite)
if E.IsClosedOrCanceled(cErr) && handshakeFinished {
return nil
}
return cErr
})
group.Cleanup(func() {
handshakeConn.Close()
})
err = group.Run(ctx)
if err != nil {
return E.Cause(err, "handshake relay")
}
s.logger.TraceContext(ctx, "handshake relay finished")
return s.handler.NewConnection(ctx, bufio.NewCachedConn(newVerifiedConn(conn, hmacAdd, hmacVerify, nil), clientFirstFrame), metadata)
}
}