sing/protocol/socks/lazy.go
2024-12-10 19:53:57 +08:00

215 lines
4.6 KiB
Go

package socks
import (
	"net"
	"os"
	"sync"

	"github.com/sagernet/sing/common/buf"
	"github.com/sagernet/sing/common/bufio"
	M "github.com/sagernet/sing/common/metadata"
	"github.com/sagernet/sing/protocol/socks/socks4"
	"github.com/sagernet/sing/protocol/socks/socks5"
)
// LazyConn wraps a SOCKS4/SOCKS5 client connection whose handshake reply is
// deferred. The reply is sent explicitly via ConnHandshakeSuccess or
// HandshakeFailure, or implicitly as a success reply on the first Read or
// Write. At most one reply is ever written.
type LazyConn struct {
	net.Conn
	socksVersion    byte
	responseWritten bool

	// responseAccess guards responseWritten and serializes writing the reply.
	// net.Conn permits concurrent Read and Write calls; without this lock both
	// could observe responseWritten == false and send the reply twice, or a
	// payload Write could reach the client before the reply.
	responseAccess sync.Mutex
}

// NewLazyConn wraps conn and defers its handshake reply for the given SOCKS
// version. Only socks4.Version and socks5.Version are supported; any other
// value makes the reply methods panic.
func NewLazyConn(conn net.Conn, socksVersion byte) *LazyConn {
	return &LazyConn{
		Conn:         conn,
		socksVersion: socksVersion,
	}
}

// ConnHandshakeSuccess writes a success reply to the client. It reports
// conn's local address as the SOCKS4 destination or the SOCKS5 bind address.
//
// It returns nil without writing if a reply was already written. The reply is
// attempted only once: if the write fails, the error is returned and the reply
// is still considered written.
//
// It panics if the connection was created with an unknown SOCKS version.
func (c *LazyConn) ConnHandshakeSuccess(conn net.Conn) error {
	c.responseAccess.Lock()
	defer c.responseAccess.Unlock()
	if c.responseWritten {
		return nil
	}
	c.responseWritten = true
	switch c.socksVersion {
	case socks4.Version:
		return socks4.WriteResponse(c.Conn, socks4.Response{
			ReplyCode:   socks4.ReplyCodeGranted,
			Destination: M.SocksaddrFromNet(conn.LocalAddr()),
		})
	case socks5.Version:
		return socks5.WriteResponse(c.Conn, socks5.Response{
			ReplyCode: socks5.ReplyCodeSuccess,
			Bind:      M.SocksaddrFromNet(conn.LocalAddr()),
		})
	default:
		panic("unknown socks version")
	}
}

// HandshakeFailure writes a failure reply to the client. For SOCKS4 this is
// "rejected or failed". For SOCKS5 the reply code is chosen by
// socks5.ReplyCodeForError(err).
//
// It returns os.ErrInvalid without writing anything if a reply was already
// written. It panics if the connection was created with an unknown SOCKS
// version.
func (c *LazyConn) HandshakeFailure(err error) error {
	c.responseAccess.Lock()
	defer c.responseAccess.Unlock()
	if c.responseWritten {
		return os.ErrInvalid
	}
	c.responseWritten = true
	switch c.socksVersion {
	case socks4.Version:
		return socks4.WriteResponse(c.Conn, socks4.Response{
			ReplyCode: socks4.ReplyCodeRejectedOrFailed,
		})
	case socks5.Version:
		return socks5.WriteResponse(c.Conn, socks5.Response{
			ReplyCode: socks5.ReplyCodeForError(err),
		})
	default:
		panic("unknown socks version")
	}
}

// Read sends the pending success reply, if any, and then reads from the
// client connection.
func (c *LazyConn) Read(p []byte) (n int, err error) {
	if err = c.ConnHandshakeSuccess(c.Conn); err != nil {
		return 0, err
	}
	return c.Conn.Read(p)
}

// Write sends the pending success reply, if any, and then writes p to the
// client connection. The reply is therefore always ahead of the payload.
func (c *LazyConn) Write(p []byte) (n int, err error) {
	if err = c.ConnHandshakeSuccess(c.Conn); err != nil {
		return 0, err
	}
	return c.Conn.Write(p)
}

// ReaderReplaceable reports whether the reply has been written. Once it has,
// reads no longer need to go through this wrapper.
func (c *LazyConn) ReaderReplaceable() bool {
	c.responseAccess.Lock()
	defer c.responseAccess.Unlock()
	return c.responseWritten
}

// WriterReplaceable reports whether the reply has been written. Once it has,
// writes no longer need to go through this wrapper.
func (c *LazyConn) WriterReplaceable() bool {
	c.responseAccess.Lock()
	defer c.responseAccess.Unlock()
	return c.responseWritten
}

// Upstream returns the wrapped client connection.
func (c *LazyConn) Upstream() any {
	return c.Conn
}
// LazyAssociatePacketConn wraps a SOCKS5 UDP ASSOCIATE session whose reply on
// the control connection (underlying) is deferred. The reply is sent
// explicitly via HandshakeSuccess or HandshakeFailure, or implicitly as a
// success reply on the first packet read or write. At most one reply is ever
// written.
type LazyAssociatePacketConn struct {
	AssociatePacketConn
	responseWritten bool

	// responseAccess guards responseWritten and serializes writing the reply.
	// Without it, concurrent readers and writers of the packet conn could both
	// send the reply, and a packet operation could proceed before the reply
	// has been written.
	responseAccess sync.Mutex
}

// NewLazyAssociatePacketConn wraps the packet connection conn. The
// connection is also wrapped with bufio.NewExtendedConn. underlying is the
// connection the SOCKS5 reply is written to; presumably this is the client's
// TCP control connection.
func NewLazyAssociatePacketConn(conn net.Conn, underlying net.Conn) *LazyAssociatePacketConn {
	return &LazyAssociatePacketConn{
		AssociatePacketConn: AssociatePacketConn{
			AbstractConn: conn,
			conn:         bufio.NewExtendedConn(conn),
			underlying:   underlying,
		},
	}
}

// HandshakeSuccess writes a SOCKS5 success reply to underlying. It reports the
// packet connection's local address as the bind address.
//
// It returns nil without writing if a reply was already written. The reply is
// attempted only once: a failed write is returned and not retried.
func (c *LazyAssociatePacketConn) HandshakeSuccess() error {
	c.responseAccess.Lock()
	defer c.responseAccess.Unlock()
	if c.responseWritten {
		return nil
	}
	c.responseWritten = true
	return socks5.WriteResponse(c.underlying, socks5.Response{
		ReplyCode: socks5.ReplyCodeSuccess,
		Bind:      M.SocksaddrFromNet(c.conn.LocalAddr()),
	})
}

// HandshakeFailure writes a SOCKS5 failure reply to underlying, using the
// reply code chosen by socks5.ReplyCodeForError(err). It then closes both the
// packet connection and underlying.
//
// It returns os.ErrInvalid, and closes nothing, if a reply was already
// written.
func (c *LazyAssociatePacketConn) HandshakeFailure(err error) error {
	c.responseAccess.Lock()
	defer c.responseAccess.Unlock()
	if c.responseWritten {
		return os.ErrInvalid
	}
	c.responseWritten = true
	// Close after the reply has been attempted. Close errors are ignored
	// because this is best-effort teardown.
	defer func() {
		c.conn.Close()
		c.underlying.Close()
	}()
	return socks5.WriteResponse(c.underlying, socks5.Response{
		ReplyCode: socks5.ReplyCodeForError(err),
	})
}

// ReadFrom sends the pending success reply, if any, and then reads a packet.
func (c *LazyAssociatePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
	if err = c.HandshakeSuccess(); err != nil {
		return 0, nil, err
	}
	return c.AssociatePacketConn.ReadFrom(p)
}

// WriteTo sends the pending success reply, if any, and then writes a packet
// to addr.
func (c *LazyAssociatePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
	if err = c.HandshakeSuccess(); err != nil {
		return 0, err
	}
	return c.AssociatePacketConn.WriteTo(p, addr)
}

// ReadPacket sends the pending success reply, if any, and then reads a packet
// into buffer.
func (c *LazyAssociatePacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
	if err = c.HandshakeSuccess(); err != nil {
		return
	}
	return c.AssociatePacketConn.ReadPacket(buffer)
}

// WritePacket sends the pending success reply, if any, and then writes
// buffer to destination.
func (c *LazyAssociatePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
	if err := c.HandshakeSuccess(); err != nil {
		return err
	}
	return c.AssociatePacketConn.WritePacket(buffer, destination)
}

// Read sends the pending success reply, if any, and then reads from the
// associate connection.
func (c *LazyAssociatePacketConn) Read(p []byte) (n int, err error) {
	if err = c.HandshakeSuccess(); err != nil {
		return 0, err
	}
	return c.AssociatePacketConn.Read(p)
}

// Write sends the pending success reply, if any, and then writes to the
// associate connection.
func (c *LazyAssociatePacketConn) Write(p []byte) (n int, err error) {
	if err = c.HandshakeSuccess(); err != nil {
		return 0, err
	}
	return c.AssociatePacketConn.Write(p)
}

// ReaderReplaceable reports whether the reply has been written. Once it has,
// reads no longer need to go through this wrapper.
func (c *LazyAssociatePacketConn) ReaderReplaceable() bool {
	c.responseAccess.Lock()
	defer c.responseAccess.Unlock()
	return c.responseWritten
}

// WriterReplaceable reports whether the reply has been written. Once it has,
// writes no longer need to go through this wrapper.
func (c *LazyAssociatePacketConn) WriterReplaceable() bool {
	c.responseAccess.Lock()
	defer c.responseAccess.Unlock()
	return c.responseWritten
}

// Upstream returns a pointer to the embedded AssociatePacketConn.
func (c *LazyAssociatePacketConn) Upstream() any {
	return &c.AssociatePacketConn
}