// Mirror of https://github.com/SagerNet/sing-shadowtls.git
// Synced 2025-03-31 10:47:35 +03:00 (198 lines, 5.5 KiB, Go).
package shadowtls
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/hmac"
|
|
"crypto/sha1"
|
|
"crypto/tls"
|
|
"encoding/binary"
|
|
"hash"
|
|
"io"
|
|
"net"
|
|
|
|
"github.com/sagernet/sing/common"
|
|
"github.com/sagernet/sing/common/buf"
|
|
"github.com/sagernet/sing/common/bufio"
|
|
E "github.com/sagernet/sing/common/exceptions"
|
|
)
|
|
|
|
// extractFrame reads exactly one TLS record from conn: the record header plus
// the payload length declared in header bytes 3-4 (big-endian). The record is
// returned, header included, in a buffer allocated with buf.NewSize. The
// caller owns it and must Release it.
//
// On any read error the partially filled buffer is released here and nil is
// returned, so callers never receive a buffer that has already been released.
func extractFrame(conn net.Conn) (*buf.Buffer, error) {
	var tlsHeader [tlsHeaderSize]byte
	if _, err := io.ReadFull(conn, tlsHeader[:]); err != nil {
		return nil, err
	}
	length := int(binary.BigEndian.Uint16(tlsHeader[3:]))
	buffer := buf.NewSize(tlsHeaderSize + length)
	common.Must1(buffer.Write(tlsHeader[:]))
	if _, err := buffer.ReadFullFrom(conn, length); err != nil {
		buffer.Release()
		return nil, err
	}
	return buffer, nil
}
|
|
|
|
// extractServerName parses frame, which should hold a ClientHello record, and
// returns the SNI server name it carries (empty if the client sent none).
//
// crypto/tls is used as the parser. The frame is fed to a server-side
// tls.Conn through bufio.NewReadOnlyConn, and the parsed ClientHello is
// captured in GetConfigForClient. The callback then returns an error so the
// handshake stops immediately after parsing. The throw-away config carries no
// certificates, so the handshake could never succeed, and continuing would
// only waste work (key exchange, certificate selection).
//
// If no ClientHello could be parsed, the handshake error is returned.
func extractServerName(frame []byte) (string, error) {
	var hello *tls.ClientHelloInfo
	err := tls.Server(bufio.NewReadOnlyConn(bytes.NewReader(frame)), &tls.Config{
		GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
			hello = argHello
			// Abort: everything needed has been captured.
			return nil, E.New("client hello captured")
		},
	}).HandshakeContext(context.Background())
	if hello != nil {
		return hello.ServerName, nil
	}
	return "", err
}
|
|
|
|
// verifyClientHello authenticates a ClientHello record against users and
// returns a copy of the first user whose password validates it.
//
// The tag checked here occupies the final hmacSize bytes of the
// tlsSessionIDSize-byte session ID. For each user, HMAC-SHA1 keyed with the
// user's password is recomputed over the handshake message (record header
// excluded), with the tag bytes replaced by zeros. The truncated result is
// compared against the tag in constant time.
//
// Errors:
//   - io.ErrUnexpectedEOF if frame is too short to contain the session ID;
//   - a descriptive error if frame is not a handshake record holding a
//     ClientHello with a tlsSessionIDSize-byte session ID;
//   - "hmac mismatch" if no user's password validates the tag.
func verifyClientHello(frame []byte, users []User) (*User, error) {
	const minLen = tlsHeaderSize + 1 + 3 + 2 + tlsRandomSize + 1 + tlsSessionIDSize
	// The tag is the last hmacSize bytes of the session ID.
	const hmacIndex = sessionIDLengthIndex + 1 + tlsSessionIDSize - hmacSize

	switch {
	case len(frame) < minLen:
		return nil, io.ErrUnexpectedEOF
	case frame[0] != handshake:
		return nil, E.New("unexpected record type")
	case frame[5] != clientHello: // first byte after the record header: handshake type
		return nil, E.New("unexpected handshake type")
	case frame[sessionIDLengthIndex] != tlsSessionIDSize:
		return nil, E.New("unexpected session id length")
	}

	// Stand-in for the tag while recomputing the MAC. It is sized from hmacSize
	// so it always matches the span skipped below.
	var zeroTag [hmacSize]byte
	for _, user := range users {
		mac := hmac.New(sha1.New, []byte(user.Password))
		mac.Write(frame[tlsHeaderSize:hmacIndex])
		mac.Write(zeroTag[:])
		mac.Write(frame[hmacIndex+hmacSize:])
		if hmac.Equal(frame[hmacIndex:hmacIndex+hmacSize], mac.Sum(nil)[:hmacSize]) {
			matched := user // explicit copy: the result does not alias users
			return &matched, nil
		}
	}
	return nil, E.New("hmac mismatch")
}
|
|
|
|
func extractServerRandom(frame []byte) []byte {
|
|
const minLen = tlsHeaderSize + 1 + 3 + 2 + tlsRandomSize
|
|
|
|
if len(frame) < minLen || frame[0] != handshake || frame[5] != serverHello {
|
|
return nil
|
|
}
|
|
|
|
serverRandom := make([]byte, tlsRandomSize)
|
|
copy(serverRandom, frame[serverRandomIndex:serverRandomIndex+tlsRandomSize])
|
|
return serverRandom
|
|
}
|
|
|
|
// isServerHelloSupportTLS13 reports whether the ServerHello record in frame
// selects TLS 1.3. That is the case when it carries a supported_versions
// extension (type 43) whose 2-byte value is 0x0304. Malformed or truncated
// input yields false.
//
// Only the bytes covered by the ServerHello's extension-list length are
// scanned. Data that follows the ServerHello in the same record (for example
// coalesced handshake messages) is therefore never misread as an extension.
// If the frame holds fewer bytes than the list length declares (e.g. a
// ServerHello fragmented across records), the bytes that are present are
// scanned.
func isServerHelloSupportTLS13(frame []byte) bool {
	const (
		extensionSupportedVersions = 43
		// cipher_suite (2 bytes) + legacy_compression_method (1 byte)
		cipherSuiteAndCompressionSize = 3
	)

	if len(frame) <= sessionIDLengthIndex {
		return false
	}
	data := frame[sessionIDLengthIndex:]
	sessionIDLength := int(data[0])
	data = data[1:]

	// Skip the session ID, cipher suite and compression method, then make sure
	// the 2-byte extension-list length is present.
	skip := sessionIDLength + cipherSuiteAndCompressionSize
	if len(data) < skip+2 {
		return false
	}
	data = data[skip:]
	extensionsLength := int(binary.BigEndian.Uint16(data))
	data = data[2:]
	if extensionsLength < len(data) {
		data = data[:extensionsLength]
	}

	// Each extension: type (2 bytes), length (2 bytes), body (length bytes).
	for len(data) >= 4 {
		extensionType := binary.BigEndian.Uint16(data)
		extensionLength := int(binary.BigEndian.Uint16(data[2:]))
		data = data[4:]
		if len(data) < extensionLength {
			return false
		}
		if extensionType != extensionSupportedVersions {
			data = data[extensionLength:]
			continue
		}
		// In a ServerHello, supported_versions holds a single selected version.
		if extensionLength != 2 {
			return false
		}
		return binary.BigEndian.Uint16(data) == tls.VersionTLS13
	}
	return false
}
|
|
|
|
// copyByFrameUntilHMACMatches relays TLS records read from conn to
// handshakeConn until it finds an application-data record whose tag
// authenticates its payload.
//
// A candidate record has an hmacSize-byte tag right after the record header.
// The tag must equal the truncated digest of hmacVerify, after hmacReset, over
// the payload that follows the tag. Records that do not match are forwarded to
// handshakeConn unchanged and released.
//
// On a match, hmacVerify is reset and fed the payload followed by the tag. The
// caller presumably relies on this chained state for later records
// (NOTE(review): confirm against the caller). The record is returned with
// header and tag stripped. The caller owns the returned buffer and must
// Release it.
func copyByFrameUntilHMACMatches(conn net.Conn, handshakeConn net.Conn, hmacVerify hash.Hash, hmacReset func()) (*buf.Buffer, error) {
	for {
		frameBuffer, err := extractFrame(conn)
		if err != nil {
			return nil, E.Cause(err, "read client record")
		}
		frame := frameBuffer.Bytes()
		if len(frame) > tlsHmacHeaderSize && frame[0] == applicationData {
			tag := frame[tlsHeaderSize:tlsHmacHeaderSize]
			payload := frame[tlsHmacHeaderSize:]
			hmacReset()
			hmacVerify.Write(payload)
			// The tag comes from the peer. A constant-time comparison avoids
			// leaking, via timing, how many leading bytes of a forged tag match.
			if hmac.Equal(hmacVerify.Sum(nil)[:hmacSize], tag) {
				hmacReset()
				hmacVerify.Write(payload)
				hmacVerify.Write(tag)
				frameBuffer.Advance(tlsHmacHeaderSize)
				return frameBuffer, nil
			}
		}
		_, err = handshakeConn.Write(frame)
		frameBuffer.Release()
		if err != nil {
			return nil, E.Cause(err, "write client frame")
		}
	}
}
|
|
|
|
func copyByFrameWithModification(conn net.Conn, handshakeConn net.Conn, password string, serverRandom []byte, hmacWrite hash.Hash) error {
|
|
writeKey := kdf(password, serverRandom)
|
|
writer := bufio.NewVectorisedWriter(handshakeConn)
|
|
for {
|
|
frameBuffer, err := extractFrame(conn)
|
|
if err != nil {
|
|
return E.Cause(err, "read server record")
|
|
}
|
|
frame := frameBuffer.Bytes()
|
|
if frame[0] == applicationData {
|
|
xorSlice(frame[tlsHeaderSize:], writeKey)
|
|
hmacWrite.Write(frame[tlsHeaderSize:])
|
|
binary.BigEndian.PutUint16(frame[3:], uint16(len(frame)-tlsHeaderSize+hmacSize))
|
|
hmacHash := hmacWrite.Sum(nil)[:4]
|
|
_, err = bufio.WriteVectorised(writer, [][]byte{frame[:tlsHeaderSize], hmacHash, frame[tlsHeaderSize:]})
|
|
frameBuffer.Release()
|
|
if err != nil {
|
|
return E.Cause(err, "write modified server frame")
|
|
}
|
|
} else {
|
|
_, err = handshakeConn.Write(frame)
|
|
frameBuffer.Release()
|
|
if err != nil {
|
|
return E.Cause(err, "write server frame")
|
|
}
|
|
}
|
|
}
|
|
}
|