sing-shadowtls/v3_conn.go
2024-11-17 14:50:55 +08:00

207 lines
4.8 KiB
Go

package shadowtls
import (
"bytes"
"crypto/rand"
"encoding/binary"
"hash"
"io"
"net"
"sync"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
)
type verifiedConn struct {
net.Conn
writer N.ExtendedWriter
vectorisedWriter N.VectorisedWriter
access sync.Mutex
hmacAdd hash.Hash
hmacVerify hash.Hash
hmacIgnore hash.Hash
buffer *buf.Buffer
}
func newVerifiedConn(
conn net.Conn,
hmacAdd hash.Hash,
hmacVerify hash.Hash,
hmacIgnore hash.Hash,
) *verifiedConn {
return &verifiedConn{
Conn: conn,
writer: bufio.NewExtendedWriter(conn),
vectorisedWriter: bufio.NewVectorisedWriter(conn),
hmacAdd: hmacAdd,
hmacVerify: hmacVerify,
hmacIgnore: hmacIgnore,
}
}
func (c *verifiedConn) Read(b []byte) (n int, err error) {
if c.buffer != nil {
if !c.buffer.IsEmpty() {
return c.buffer.Read(b)
}
c.buffer.Release()
c.buffer = nil
}
for {
var tlsHeader [tlsHeaderSize]byte
_, err = io.ReadFull(c.Conn, tlsHeader[:])
if err != nil {
sendAlert(c.Conn)
return
}
length := int(binary.BigEndian.Uint16(tlsHeader[3:tlsHeaderSize]))
c.buffer = buf.NewSize(tlsHeaderSize + length)
common.Must1(c.buffer.Write(tlsHeader[:]))
_, err = c.buffer.ReadFullFrom(c.Conn, length)
if err != nil {
return
}
buffer := c.buffer.Bytes()
switch buffer[0] {
case alert:
err = E.Cause(net.ErrClosed, "remote alert")
return
case applicationData:
if c.hmacIgnore != nil {
if verifyApplicationData(buffer, c.hmacIgnore, false) {
c.buffer.Release()
c.buffer = nil
continue
} else {
c.hmacIgnore = nil
}
}
if !verifyApplicationData(buffer, c.hmacVerify, true) {
sendAlert(c.Conn)
err = E.New("application data verification failed")
return
}
c.buffer.Advance(tlsHmacHeaderSize)
default:
sendAlert(c.Conn)
err = E.New("unexpected TLS record type: ", buffer[0])
return
}
return c.buffer.Read(b)
}
}
func (c *verifiedConn) Write(p []byte) (n int, err error) {
pTotal := len(p)
for len(p) > 0 {
var pWrite []byte
if len(p) > 16384 {
pWrite = p[:16384]
p = p[16384:]
} else {
pWrite = p
p = nil
}
_, err = c.write(pWrite)
}
if err == nil {
n = pTotal
}
return
}
func (c *verifiedConn) write(p []byte) (n int, err error) {
var header [tlsHmacHeaderSize]byte
header[0] = applicationData
header[1] = 3
header[2] = 3
binary.BigEndian.PutUint16(header[3:tlsHeaderSize], hmacSize+uint16(len(p)))
c.access.Lock()
c.hmacAdd.Write(p)
hmacHash := c.hmacAdd.Sum(nil)[:hmacSize]
c.hmacAdd.Write(hmacHash)
c.access.Unlock()
copy(header[tlsHeaderSize:], hmacHash)
_, err = bufio.WriteVectorised(c.vectorisedWriter, [][]byte{header[:], p})
if err == nil {
n = len(p)
}
return
}
func (c *verifiedConn) WriteBuffer(buffer *buf.Buffer) error {
c.access.Lock()
c.hmacAdd.Write(buffer.Bytes())
dateLen := buffer.Len()
header := buffer.ExtendHeader(tlsHmacHeaderSize)
header[0] = applicationData
header[1] = 3
header[2] = 3
binary.BigEndian.PutUint16(header[3:tlsHeaderSize], hmacSize+uint16(dateLen))
hmacHash := c.hmacAdd.Sum(nil)[:hmacSize]
c.hmacAdd.Write(hmacHash)
c.access.Unlock()
copy(header[tlsHeaderSize:], hmacHash)
return c.writer.WriteBuffer(buffer)
}
func (c *verifiedConn) WriteVectorised(buffers []*buf.Buffer) error {
var header [tlsHmacHeaderSize]byte
header[0] = applicationData
header[1] = 3
header[2] = 3
binary.BigEndian.PutUint16(header[3:tlsHeaderSize], hmacSize+uint16(buf.LenMulti(buffers)))
c.access.Lock()
for _, buffer := range buffers {
c.hmacAdd.Write(buffer.Bytes())
}
c.hmacAdd.Write(c.hmacAdd.Sum(nil)[:hmacSize])
hmacHash := c.hmacAdd.Sum(nil)[:hmacSize]
c.access.Unlock()
copy(header[tlsHeaderSize:], hmacHash)
return c.vectorisedWriter.WriteVectorised(append([]*buf.Buffer{buf.As(header[:])}, buffers...))
}
func (c *verifiedConn) FrontHeadroom() int {
return tlsHmacHeaderSize
}
func (c *verifiedConn) NeedAdditionalReadDeadline() bool {
return true
}
func (c *verifiedConn) Upstream() any {
return c.Conn
}
func verifyApplicationData(frame []byte, hmac hash.Hash, update bool) bool {
if frame[1] != 3 || frame[2] != 3 || len(frame) < tlsHmacHeaderSize {
return false
}
hmac.Write(frame[tlsHmacHeaderSize:])
hmacHash := hmac.Sum(nil)[:hmacSize]
if update {
hmac.Write(hmacHash)
}
return bytes.Equal(frame[tlsHeaderSize:tlsHeaderSize+hmacSize], hmacHash)
}
func sendAlert(writer io.Writer) {
const recordSize = 31
record := [recordSize]byte{
alert,
3,
3,
0,
recordSize - tlsHeaderSize,
}
_, err := rand.Read(record[tlsHeaderSize:])
if err != nil {
return
}
writer.Write(record[:])
}