commit b8acf3f1454604212b926b854f2ba1224d565003 Author: 世界 Date: Sun Apr 23 17:05:39 2023 +0800 Init commit diff --git a/.github/update_dependencies.sh b/.github/update_dependencies.sh new file mode 100755 index 0000000..4702ddf --- /dev/null +++ b/.github/update_dependencies.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash + +PROJECTS=$(dirname "$0")/../.. +go get -x github.com/sagernet/$1@$(git -C $PROJECTS/$1 rev-parse HEAD) +go mod tidy diff --git a/.github/workflows/debug.yml b/.github/workflows/debug.yml new file mode 100644 index 0000000..fcbf9cb --- /dev/null +++ b/.github/workflows/debug.yml @@ -0,0 +1,43 @@ +name: Debug build + +on: + push: + branches: + - main + paths-ignore: + - '**.md' + - '.github/**' + - '!.github/workflows/debug.yml' + pull_request: + branches: + - main + +jobs: + build: + name: Debug build + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v3 + with: + fetch-depth: 0 + - name: Get latest go version + id: version + run: | + echo ::set-output name=go_version::$(curl -s https://raw.githubusercontent.com/actions/go-versions/main/versions-manifest.json | grep -oE '"version": "[0-9]{1}.[0-9]{1,}(.[0-9]{1,})?"' | head -1 | cut -d':' -f2 | sed 's/ //g; s/"//g') + - name: Setup Go + uses: actions/setup-go@v4 + with: + go-version: ${{ steps.version.outputs.go_version }} + - name: Add cache to Go proxy + run: | + version=`git rev-parse HEAD` + mkdir build + pushd build + go mod init build + go get -v github.com/sagernet/sing-mux@$version + popd + continue-on-error: true + - name: Build + run: | + make test diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..0a1c8f2 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,41 @@ +name: Lint + +on: + push: + branches: + - dev + paths-ignore: + - '**.md' + - '.github/**' + - '!.github/workflows/lint.yml' + pull_request: + branches: + - dev + +jobs: + build: + name: Build + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v3 + with: + fetch-depth: 0 + - name: Get latest go version + id: version + run: | + echo ::set-output name=go_version::$(curl -s https://raw.githubusercontent.com/actions/go-versions/main/versions-manifest.json | grep -oE '"version": "[0-9]{1}.[0-9]{1,}(.[0-9]{1,})?"' | head -1 | cut -d':' -f2 | sed 's/ //g; s/"//g') + - name: Setup Go + uses: actions/setup-go@v4 + with: + go-version: ${{ steps.version.outputs.go_version }} + - name: Cache go module + uses: actions/cache@v3 + with: + path: | + ~/go/pkg/mod + key: go-${{ hashFiles('**/go.sum') }} + - name: golangci-lint + uses: golangci/golangci-lint-action@v3 + with: + version: latest \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f7f8ac3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/.idea/ +/vendor/ diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..4133a1d --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,17 @@ +linters: + disable-all: true + enable: + - gofumpt + - govet + - gci + - staticcheck + +linters-settings: + gci: + custom-order: true + sections: + - standard + - prefix(github.com/sagernet/) + - default + staticcheck: + go: '1.20' diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..3e3e29e --- /dev/null +++ b/LICENSE @@ -0,0 +1,14 @@ +Copyright (C) 2022 by nekohasekai + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with this program. If not, see . \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..a837f60 --- /dev/null +++ b/Makefile @@ -0,0 +1,21 @@ +fmt: + @gofumpt -l -w . + @gofmt -s -w . + @gci write --custom-order -s "standard,prefix(github.com/sagernet/),default" . + +fmt_install: + go install -v mvdan.cc/gofumpt@latest + go install -v github.com/daixiang0/gci@latest + +lint: + GOOS=linux golangci-lint run ./... + GOOS=android golangci-lint run ./... + GOOS=windows golangci-lint run ./... + GOOS=darwin golangci-lint run ./... + GOOS=freebsd golangci-lint run ./... + +lint_install: + go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest + +test: + go test -v ./... \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..834fd7d --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +# sing-mux + +Simple multiplex library. \ No newline at end of file diff --git a/client.go b/client.go new file mode 100644 index 0000000..eb9d1ea --- /dev/null +++ b/client.go @@ -0,0 +1,183 @@ +package mux + +import ( + "context" + "net" + "sync" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/x/list" +) + +type Client struct { + dialer N.Dialer + protocol byte + maxConnections int + minStreams int + maxStreams int + padding bool + access sync.Mutex + connections list.List[abstractSession] +} + +type Options struct { + Dialer N.Dialer + Protocol string + MaxConnections int + MinStreams int + MaxStreams int + Padding bool +} + +func NewClient(options Options) (*Client, error) { + client := &Client{ + dialer: options.Dialer, + maxConnections: options.MaxConnections, + minStreams: options.MinStreams, + maxStreams: options.MaxStreams, + padding: options.Padding, + } + if client.dialer == nil { + client.dialer = N.SystemDialer + } + if client.maxStreams == 0 && client.maxConnections == 0 { + client.minStreams = 8 + } + switch options.Protocol { + case "", "h2mux": + client.protocol = ProtocolH2Mux + case "smux": + client.protocol = ProtocolSmux + case "yamux": + client.protocol = ProtocolYAMux + default: + return nil, E.New("unknown protocol: " + options.Protocol) + } + return client, nil +} + +func (c *Client) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + switch N.NetworkName(network) { + case N.NetworkTCP: + stream, err := c.openStream(ctx) + if err != nil { + return nil, err + } + return &clientConn{Conn: stream, destination: destination}, nil + case N.NetworkUDP: + stream, err := c.openStream(ctx) + if err != nil { + return nil, err + } + return bufio.NewUnbindPacketConn(&clientPacketConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: destination}), nil + default: + return nil, E.Extend(N.ErrUnknownNetwork, network) + } +} + +func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + stream, err := c.openStream(ctx) + if err != nil { + return nil, err + } + return &clientPacketAddrConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: destination}, nil +} + +func (c *Client) openStream(ctx context.Context) (net.Conn, error) { + var ( + session abstractSession + stream net.Conn + err error + ) + for attempts := 0; attempts < 2; attempts++ { + session, err = c.offer(ctx) + if err != nil { + continue + } + stream, err = session.Open() + if err != nil { + continue + } + break + } + if err != nil { + return nil, err + } + return &wrapStream{stream}, nil +} + +func (c *Client) offer(ctx context.Context) (abstractSession, error) { + c.access.Lock() + defer c.access.Unlock() + + sessions := make([]abstractSession, 0, c.maxConnections) + for element := c.connections.Front(); element != nil; { + if element.Value.IsClosed() { + nextElement := element.Next() + c.connections.Remove(element) + element = nextElement + continue + } + sessions = append(sessions, element.Value) + element = element.Next() + } + session := common.MinBy(common.Filter(sessions, abstractSession.CanTakeNewRequest), abstractSession.NumStreams) + if session == nil { + return c.offerNew(ctx) + } + numStreams := session.NumStreams() + if numStreams == 0 { + return session, nil + } + if c.maxConnections > 0 { + if len(sessions) >= c.maxConnections || numStreams < c.minStreams { + return session, nil + } + } else { + if c.maxStreams > 0 && numStreams < c.maxStreams { + return session, nil + } + } + return c.offerNew(ctx) +} + +func (c *Client) offerNew(ctx context.Context) (abstractSession, error) { + conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, Destination) + if err != nil { + return nil, err + } + var version byte + if c.padding { + version = Version1 + } else { + version = Version0 + } + conn = newProtocolConn(conn, Request{ + Version: version, + Protocol: c.protocol, + Padding: c.padding, + }) + if c.padding { + conn = newPaddingConn(conn) + } + session, err := newClientSession(conn, c.protocol) + if err != nil { + conn.Close() + return nil, err + } + c.connections.PushBack(session) + return session, nil +} + +func (c *Client) Reset() { + c.access.Lock() + defer c.access.Unlock() + for _, session := range c.connections.Array() { + session.Close() + } + c.connections.Init() +} diff --git a/client_conn.go b/client_conn.go new file mode 100644 index 0000000..7aba975 --- /dev/null +++ b/client_conn.go @@ -0,0 +1,380 @@ +package mux + +import ( + "encoding/binary" + "io" + "net" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +type clientConn struct { + net.Conn + destination M.Socksaddr + requestWritten bool + responseRead bool +} + +func (c *clientConn) readResponse() error { + response, err := ReadStreamResponse(c.Conn) + if err != nil { + return err + } + if response.Status == statusError { + return E.New("remote error: ", response.Message) + } + return nil +} + +func (c *clientConn) Read(b []byte) (n int, err error) { + if !c.responseRead { + err = c.readResponse() + if err != nil { + return + } + c.responseRead = true + } + return c.Conn.Read(b) +} + +func (c *clientConn) Write(b []byte) (n int, err error) { + if c.requestWritten { + return c.Conn.Write(b) + } + request := StreamRequest{ + Network: N.NetworkTCP, + Destination: c.destination, + } + _buffer := buf.StackNewSize(streamRequestLen(request) + len(b)) + defer common.KeepAlive(_buffer) + buffer := common.Dup(_buffer) + defer buffer.Release() + EncodeStreamRequest(request, buffer) + buffer.Write(b) + _, err = c.Conn.Write(buffer.Bytes()) + if err != nil { + return + } + c.requestWritten = true + return len(b), nil +} + +func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) { + if !c.requestWritten { + return bufio.ReadFrom0(c, r) + } + return bufio.Copy(c.Conn, r) +} + +func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) { + if !c.responseRead { + return bufio.WriteTo0(c, w) + } + return bufio.Copy(w, c.Conn) +} + +func (c *clientConn) LocalAddr() net.Addr { + return c.Conn.LocalAddr() +} + +func (c *clientConn) RemoteAddr() net.Addr { + return c.destination.TCPAddr() +} + +func (c *clientConn) ReaderReplaceable() bool { + return c.responseRead +} + +func (c *clientConn) WriterReplaceable() bool { + return c.requestWritten +} + +func (c *clientConn) NeedAdditionalReadDeadline() bool { + return true +} + +func (c *clientConn) Upstream() any { + return c.Conn +} + +type clientPacketConn struct { + N.ExtendedConn + destination M.Socksaddr + requestWritten bool + responseRead bool +} + +func (c *clientPacketConn) readResponse() error { + response, err := ReadStreamResponse(c.ExtendedConn) + if err != nil { + return err + } + if response.Status == statusError { + return E.New("remote error: ", response.Message) + } + return nil +} + +func (c *clientPacketConn) Read(b []byte) (n int, err error) { + if !c.responseRead { + err = c.readResponse() + if err != nil { + return + } + c.responseRead = true + } + var length uint16 + err = binary.Read(c.ExtendedConn, binary.BigEndian, &length) + if err != nil { + return + } + if cap(b) < int(length) { + return 0, io.ErrShortBuffer + } + return io.ReadFull(c.ExtendedConn, b[:length]) +} + +func (c *clientPacketConn) writeRequest(payload []byte) (n int, err error) { + request := StreamRequest{ + Network: N.NetworkUDP, + Destination: c.destination, + } + rLen := streamRequestLen(request) + if len(payload) > 0 { + rLen += 2 + len(payload) + } + _buffer := buf.StackNewSize(rLen) + defer common.KeepAlive(_buffer) + buffer := common.Dup(_buffer) + defer buffer.Release() + EncodeStreamRequest(request, buffer) + if len(payload) > 0 { + common.Must( + binary.Write(buffer, binary.BigEndian, uint16(len(payload))), + common.Error(buffer.Write(payload)), + ) + } + _, err = c.ExtendedConn.Write(buffer.Bytes()) + if err != nil { + return + } + c.requestWritten = true + return len(payload), nil +} + +func (c *clientPacketConn) Write(b []byte) (n int, err error) { + if !c.requestWritten { + return c.writeRequest(b) + } + err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(b))) + if err != nil { + return + } + return c.ExtendedConn.Write(b) +} + +func (c *clientPacketConn) ReadBuffer(buffer *buf.Buffer) (err error) { + if !c.responseRead { + err = c.readResponse() + if err != nil { + return + } + c.responseRead = true + } + var length uint16 + err = binary.Read(c.ExtendedConn, binary.BigEndian, &length) + if err != nil { + return + } + _, err = buffer.ReadFullFrom(c.ExtendedConn, int(length)) + return +} + +func (c *clientPacketConn) WriteBuffer(buffer *buf.Buffer) error { + if !c.requestWritten { + defer buffer.Release() + return common.Error(c.writeRequest(buffer.Bytes())) + } + bLen := buffer.Len() + binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(bLen)) + return c.ExtendedConn.WriteBuffer(buffer) +} + +func (c *clientPacketConn) FrontHeadroom() int { + return 2 +} + +func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + err = c.ReadBuffer(buffer) + return +} + +func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + return c.WriteBuffer(buffer) +} + +func (c *clientPacketConn) LocalAddr() net.Addr { + return c.ExtendedConn.LocalAddr() +} + +func (c *clientPacketConn) RemoteAddr() net.Addr { + return c.destination.UDPAddr() +} + +func (c *clientPacketConn) NeedAdditionalReadDeadline() bool { + return true +} + +func (c *clientPacketConn) Upstream() any { + return c.ExtendedConn +} + +var _ N.NetPacketConn = (*clientPacketAddrConn)(nil) + +type clientPacketAddrConn struct { + N.ExtendedConn + destination M.Socksaddr + requestWritten bool + responseRead bool +} + +func (c *clientPacketAddrConn) readResponse() error { + response, err := ReadStreamResponse(c.ExtendedConn) + if err != nil { + return err + } + if response.Status == statusError { + return E.New("remote error: ", response.Message) + } + return nil +} + +func (c *clientPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + if !c.responseRead { + err = c.readResponse() + if err != nil { + return + } + c.responseRead = true + } + destination, err := M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn) + if err != nil { + return + } + if destination.IsFqdn() { + addr = destination + } else { + addr = destination.UDPAddr() + } + var length uint16 + err = binary.Read(c.ExtendedConn, binary.BigEndian, &length) + if err != nil { + return + } + if cap(p) < int(length) { + return 0, nil, io.ErrShortBuffer + } + n, err = io.ReadFull(c.ExtendedConn, p[:length]) + return +} + +func (c *clientPacketAddrConn) writeRequest(payload []byte, destination M.Socksaddr) (n int, err error) { + request := StreamRequest{ + Network: N.NetworkUDP, + Destination: c.destination, + PacketAddr: true, + } + rLen := streamRequestLen(request) + if len(payload) > 0 { + rLen += M.SocksaddrSerializer.AddrPortLen(destination) + 2 + len(payload) + } + _buffer := buf.StackNewSize(rLen) + defer common.KeepAlive(_buffer) + buffer := common.Dup(_buffer) + defer buffer.Release() + EncodeStreamRequest(request, buffer) + if len(payload) > 0 { + common.Must( + M.SocksaddrSerializer.WriteAddrPort(buffer, destination), + binary.Write(buffer, binary.BigEndian, uint16(len(payload))), + common.Error(buffer.Write(payload)), + ) + } + _, err = c.ExtendedConn.Write(buffer.Bytes()) + if err != nil { + return + } + c.requestWritten = true + return len(payload), nil +} + +func (c *clientPacketAddrConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if !c.requestWritten { + return c.writeRequest(p, M.SocksaddrFromNet(addr)) + } + err = M.SocksaddrSerializer.WriteAddrPort(c.ExtendedConn, M.SocksaddrFromNet(addr)) + if err != nil { + return + } + err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(p))) + if err != nil { + return + } + return c.ExtendedConn.Write(p) +} + +func (c *clientPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + if !c.responseRead { + err = c.readResponse() + if err != nil { + return + } + c.responseRead = true + } + destination, err = M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn) + if err != nil { + return + } + var length uint16 + err = binary.Read(c.ExtendedConn, binary.BigEndian, &length) + if err != nil { + return + } + _, err = buffer.ReadFullFrom(c.ExtendedConn, int(length)) + return +} + +func (c *clientPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + if !c.requestWritten { + defer buffer.Release() + return common.Error(c.writeRequest(buffer.Bytes(), destination)) + } + bLen := buffer.Len() + header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination) + 2)) + common.Must( + M.SocksaddrSerializer.WriteAddrPort(header, destination), + binary.Write(header, binary.BigEndian, uint16(bLen)), + ) + return c.ExtendedConn.WriteBuffer(buffer) +} + +func (c *clientPacketAddrConn) LocalAddr() net.Addr { + return c.ExtendedConn.LocalAddr() +} + +func (c *clientPacketAddrConn) FrontHeadroom() int { + return 2 + M.MaxSocksaddrLength +} + +func (c *clientPacketAddrConn) NeedAdditionalReadDeadline() bool { + return true +} + +func (c *clientPacketAddrConn) Upstream() any { + return c.ExtendedConn +} diff --git a/error.go b/error.go new file mode 100644 index 0000000..aa7ea8b --- /dev/null +++ b/error.go @@ -0,0 +1,37 @@ +package mux + +import ( + "io" + "net" + + "github.com/hashicorp/yamux" +) + +type wrapStream struct { + net.Conn +} + +func (w *wrapStream) Read(p []byte) (n int, err error) { + n, err = w.Conn.Read(p) + err = wrapError(err) + return +} + +func (w *wrapStream) Write(p []byte) (n int, err error) { + n, err = w.Conn.Write(p) + err = wrapError(err) + return +} + +func (w *wrapStream) Upstream() any { + return w.Conn +} + +func wrapError(err error) error { + switch err { + case yamux.ErrStreamClosed: + return io.EOF + default: + return err + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..7404a9c --- /dev/null +++ b/go.mod @@ -0,0 +1,15 @@ +module github.com/sagernet/sing-mux + +go 1.18 + +require ( + github.com/hashicorp/yamux v0.1.1 + github.com/sagernet/sing v0.2.5-0.20230423085534-0902e6216207 + github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37 + golang.org/x/net v0.9.0 +) + +require ( + golang.org/x/sys v0.7.0 // indirect + golang.org/x/text v0.9.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..374667d --- /dev/null +++ b/go.sum @@ -0,0 +1,14 @@ +github.com/hashicorp/yamux v0.1.1 h1:yrQxtgseBDrq9Y652vSRDvsKCJKOUD+GzTS4Y0Y8pvE= +github.com/hashicorp/yamux v0.1.1/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ= +github.com/sagernet/sing v0.1.8/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk= +github.com/sagernet/sing v0.2.5-0.20230423085534-0902e6216207 h1:+dDVjW20IT+e8maKryaDeRY2+RFmTFdrQeIzqE2WOss= +github.com/sagernet/sing v0.2.5-0.20230423085534-0902e6216207/go.mod h1:Ta8nHnDLAwqySzKhGoKk4ZIB+vJ3GTKj7UPrWYvM+4w= +github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37 h1:HuE6xSwco/Xed8ajZ+coeYLmioq0Qp1/Z2zczFaV8as= +github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37/go.mod h1:3skNSftZDJWTGVtVaM2jfbce8qHnmH/AGDRe62iNOg0= +golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM= +golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU= +golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= diff --git a/h2mux.go b/h2mux.go new file mode 100644 index 0000000..2ba0b4a --- /dev/null +++ b/h2mux.go @@ -0,0 +1,242 @@ +package mux + +import ( + "context" + "crypto/tls" + "io" + "net" + "net/http" + "net/url" + "os" + "time" + + "github.com/sagernet/sing/common/atomic" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" + N "github.com/sagernet/sing/common/network" + + "golang.org/x/net/http2" +) + +const idleTimeout = 30 * time.Second + +var _ abstractSession = (*h2MuxServerSession)(nil) + +type h2MuxServerSession struct { + server http2.Server + active atomic.Int32 + conn net.Conn + inbound chan net.Conn + done chan struct{} +} + +func newH2MuxServer(conn net.Conn) *h2MuxServerSession { + session := &h2MuxServerSession{ + conn: conn, + inbound: make(chan net.Conn), + done: make(chan struct{}), + server: http2.Server{ + IdleTimeout: idleTimeout, + }, + } + go func() { + session.server.ServeConn(conn, &http2.ServeConnOpts{ + Handler: session, + }) + _ = session.Close() + }() + return session +} + +func (s *h2MuxServerSession) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + s.active.Add(1) + defer s.active.Add(-1) + writer.WriteHeader(http.StatusOK) + conn := newHTTP2Wrapper(newHTTPConn(request.Body, writer), writer.(http.Flusher)) + s.inbound <- conn + select { + case <-conn.done: + case <-s.done: + _ = conn.Close() + } +} + +func (s *h2MuxServerSession) Open() (net.Conn, error) { + return nil, os.ErrInvalid +} + +func (s *h2MuxServerSession) Accept() (net.Conn, error) { + select { + case conn := <-s.inbound: + return conn, nil + case <-s.done: + return nil, os.ErrClosed + } +} + +func (s *h2MuxServerSession) NumStreams() int { + return int(s.active.Load()) +} + +func (s *h2MuxServerSession) Close() error { + select { + case <-s.done: + default: + close(s.done) + } + return s.conn.Close() +} + +func (s *h2MuxServerSession) IsClosed() bool { + select { + case <-s.done: + return true + default: + return false + } +} + +func (s *h2MuxServerSession) CanTakeNewRequest() bool { + return false +} + +type h2MuxConnWrapper struct { + N.ExtendedConn + flusher http.Flusher + done chan struct{} +} + +func newHTTP2Wrapper(conn net.Conn, flusher http.Flusher) *h2MuxConnWrapper { + return &h2MuxConnWrapper{ + ExtendedConn: bufio.NewExtendedConn(conn), + flusher: flusher, + done: make(chan struct{}), + } +} + +func (w *h2MuxConnWrapper) Write(p []byte) (n int, err error) { + select { + case <-w.done: + return 0, net.ErrClosed + default: + } + n, err = w.ExtendedConn.Write(p) + if err == nil { + w.flusher.Flush() + } + return +} + +func (w *h2MuxConnWrapper) WriteBuffer(buffer *buf.Buffer) error { + select { + case <-w.done: + return net.ErrClosed + default: + } + err := w.ExtendedConn.WriteBuffer(buffer) + if err == nil { + w.flusher.Flush() + } + return err +} + +func (w *h2MuxConnWrapper) Close() error { + select { + case <-w.done: + default: + close(w.done) + } + return w.ExtendedConn.Close() +} + +func (w *h2MuxConnWrapper) Upstream() any { + return w.ExtendedConn +} + +var _ abstractSession = (*h2MuxClientSession)(nil) + +type h2MuxClientSession struct { + transport *http2.Transport + clientConn *http2.ClientConn + done chan struct{} +} + +func newH2MuxClient(conn net.Conn) (*h2MuxClientSession, error) { + session := &h2MuxClientSession{ + transport: &http2.Transport{ + DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { + return conn, nil + }, + ReadIdleTimeout: idleTimeout, + }, + done: make(chan struct{}), + } + session.transport.ConnPool = session + clientConn, err := session.transport.NewClientConn(conn) + if err != nil { + return nil, err + } + session.clientConn = clientConn + return session, nil +} + +func (s *h2MuxClientSession) GetClientConn(req *http.Request, addr string) (*http2.ClientConn, error) { + return s.clientConn, nil +} + +func (s *h2MuxClientSession) MarkDead(conn *http2.ClientConn) { + s.Close() +} + +func (s *h2MuxClientSession) Open() (net.Conn, error) { + pipeInReader, pipeInWriter := io.Pipe() + request := &http.Request{ + Method: http.MethodConnect, + Body: pipeInReader, + URL: &url.URL{Scheme: "https", Host: "localhost"}, + } + conn := newLateHTTPConn(pipeInWriter) + go func() { + response, err := s.transport.RoundTrip(request) + if err != nil { + conn.setup(nil, err) + } else if response.StatusCode != 200 { + response.Body.Close() + conn.setup(nil, E.New("unexpected status: ", response.StatusCode, " ", response.Status)) + } else { + conn.setup(response.Body, nil) + } + }() + return conn, nil +} + +func (s *h2MuxClientSession) Accept() (net.Conn, error) { + return nil, os.ErrInvalid +} + +func (s *h2MuxClientSession) NumStreams() int { + return s.clientConn.State().StreamsActive +} + +func (s *h2MuxClientSession) Close() error { + select { + case <-s.done: + default: + close(s.done) + } + return s.clientConn.Close() +} + +func (s *h2MuxClientSession) IsClosed() bool { + select { + case <-s.done: + return true + default: + } + return s.clientConn.State().Closed +} + +func (s *h2MuxClientSession) CanTakeNewRequest() bool { + return s.clientConn.CanTakeNewRequest() +} diff --git a/h2mux_conn.go b/h2mux_conn.go new file mode 100644 index 0000000..293e3a5 --- /dev/null +++ b/h2mux_conn.go @@ -0,0 +1,82 @@ +package mux + +import ( + "io" + "net" + "os" + "time" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/baderror" +) + +type httpConn struct { + reader io.Reader + writer io.Writer + create chan struct{} + err error +} + +func newHTTPConn(reader io.Reader, writer io.Writer) *httpConn { + return &httpConn{ + reader: reader, + writer: writer, + } +} + +func newLateHTTPConn(writer io.Writer) *httpConn { + return &httpConn{ + create: make(chan struct{}), + writer: writer, + } +} + +func (c *httpConn) setup(reader io.Reader, err error) { + c.reader = reader + c.err = err + close(c.create) +} + +func (c *httpConn) Read(b []byte) (n int, err error) { + if c.reader == nil { + <-c.create + if c.err != nil { + return 0, c.err + } + } + n, err = c.reader.Read(b) + return n, baderror.WrapH2(err) +} + +func (c *httpConn) Write(b []byte) (n int, err error) { + n, err = c.writer.Write(b) + return n, baderror.WrapH2(err) +} + +func (c *httpConn) Close() error { + return common.Close(c.reader, c.writer) +} + +func (c *httpConn) LocalAddr() net.Addr { + return nil +} + +func (c *httpConn) RemoteAddr() net.Addr { + return nil +} + +func (c *httpConn) SetDeadline(t time.Time) error { + return os.ErrInvalid +} + +func (c *httpConn) SetReadDeadline(t time.Time) error { + return os.ErrInvalid +} + +func (c *httpConn) SetWriteDeadline(t time.Time) error { + return os.ErrInvalid +} + +func (c *httpConn) NeedAdditionalReadDeadline() bool { + return true +} diff --git a/padding.go b/padding.go new file mode 100644 index 0000000..850bf25 --- /dev/null +++ b/padding.go @@ -0,0 +1,240 @@ +package mux + +import ( + "encoding/binary" + "io" + "math/rand" + "net" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/rw" +) + +const kFirstPaddings = 16 + +type paddingConn struct { + N.ExtendedConn + writer N.VectorisedWriter + readPadding int + writePadding int + readRemaining int + paddingRemaining int +} + +func newPaddingConn(conn net.Conn) net.Conn { + writer, isVectorised := bufio.CreateVectorisedWriter(conn) + if isVectorised { + return &vectorisedPaddingConn{ + paddingConn{ + ExtendedConn: bufio.NewExtendedConn(conn), + writer: bufio.NewVectorisedWriter(conn), + }, + writer, + } + } else { + return &paddingConn{ + ExtendedConn: bufio.NewExtendedConn(conn), + writer: bufio.NewVectorisedWriter(conn), + } + } +} + +func (c *paddingConn) Read(p []byte) (n int, err error) { + if c.readRemaining > 0 { + if len(p) > c.readRemaining { + p = p[:c.readRemaining] + } + n, err = c.ExtendedConn.Read(p) + if err != nil { + return + } + c.readRemaining -= n + return + } + if c.paddingRemaining > 0 { + err = rw.SkipN(c.ExtendedConn, c.paddingRemaining) + if err != nil { + return + } + c.paddingRemaining = 0 + } + if c.readPadding < kFirstPaddings { + var paddingHdr []byte + if len(p) >= 4 { + paddingHdr = p[:4] + } else { + _paddingHdr := make([]byte, 4) + defer common.KeepAlive(_paddingHdr) + paddingHdr = common.Dup(_paddingHdr) + } + _, err = io.ReadFull(c.ExtendedConn, paddingHdr) + if err != nil { + return + } + originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2])) + paddingLen := int(binary.BigEndian.Uint16(paddingHdr[2:])) + if len(p) > originalDataSize { + p = p[:originalDataSize] + } + n, err = c.ExtendedConn.Read(p) + if err != nil { + return + } + c.readPadding++ + c.readRemaining = originalDataSize - n + c.paddingRemaining = paddingLen + return + } + return c.ExtendedConn.Read(p) +} + +func (c *paddingConn) Write(p []byte) (n int, err error) { + for pLen := len(p); pLen > 0; { + var data []byte + if pLen > 65535 { + data = p[:65535] + p = p[65535:] + pLen -= 65535 + } else { + data = p + pLen = 0 + } + var writeN int + writeN, err = c.write(data) + n += writeN + if err != nil { + break + } + } + return n, err +} + +func (c *paddingConn) write(p []byte) (n int, err error) { + if c.writePadding < kFirstPaddings { + paddingLen := 256 + rand.Intn(512) + _buffer := buf.StackNewSize(4 + len(p) + paddingLen) + defer common.KeepAlive(_buffer) + buffer := common.Dup(_buffer) + defer buffer.Release() + header := buffer.Extend(4) + binary.BigEndian.PutUint16(header[:2], uint16(len(p))) + binary.BigEndian.PutUint16(header[2:], uint16(paddingLen)) + common.Must1(buffer.Write(p)) + buffer.Extend(paddingLen) + _, err = c.ExtendedConn.Write(buffer.Bytes()) + if err == nil { + n = len(p) + } + c.writePadding++ + return + } + return c.ExtendedConn.Write(p) +} + +func (c *paddingConn) ReadBuffer(buffer *buf.Buffer) error { + p := buffer.FreeBytes() + if c.readRemaining > 0 { + if len(p) > c.readRemaining { + p = p[:c.readRemaining] + } + n, err := c.ExtendedConn.Read(p) + if err != nil { + return err + } + c.readRemaining -= n + buffer.Truncate(n) + return nil + } + if c.paddingRemaining > 0 { + err := rw.SkipN(c.ExtendedConn, c.paddingRemaining) + if err != nil { + return err + } + c.paddingRemaining = 0 + } + if c.readPadding < kFirstPaddings { + var paddingHdr []byte + if len(p) >= 4 { + paddingHdr = p[:4] + } else { + _paddingHdr := make([]byte, 4) + defer common.KeepAlive(_paddingHdr) + paddingHdr = common.Dup(_paddingHdr) + } + _, err := io.ReadFull(c.ExtendedConn, paddingHdr) + if err != nil { + return err + } + originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2])) + paddingLen := int(binary.BigEndian.Uint16(paddingHdr[2:])) + + if len(p) > originalDataSize { + p = p[:originalDataSize] + } + n, err := c.ExtendedConn.Read(p) + if err != nil { + return err + } + c.readPadding++ + c.readRemaining = originalDataSize - n + c.paddingRemaining = paddingLen + buffer.Truncate(n) + return nil + } + return c.ExtendedConn.ReadBuffer(buffer) +} + +func (c *paddingConn) WriteBuffer(buffer *buf.Buffer) error { + if c.writePadding < kFirstPaddings { + bufferLen := buffer.Len() + if bufferLen > 65535 { + return common.Error(c.Write(buffer.Bytes())) + } + paddingLen := 256 + rand.Intn(512) + header := buffer.ExtendHeader(4) + binary.BigEndian.PutUint16(header[:2], uint16(bufferLen)) + binary.BigEndian.PutUint16(header[2:], uint16(paddingLen)) + buffer.Extend(paddingLen) + c.writePadding++ + } + return c.ExtendedConn.WriteBuffer(buffer) +} + +func (c *paddingConn) FrontHeadroom() int { + return 4 + 256 + 1024 +} + +type vectorisedPaddingConn struct { + paddingConn + writer N.VectorisedWriter +} + +func (c *vectorisedPaddingConn) WriteVectorised(buffers []*buf.Buffer) error { + if c.writePadding < kFirstPaddings { + bufferLen := buf.LenMulti(buffers) + if bufferLen > 65535 { + defer buf.ReleaseMulti(buffers) + for _, buffer := range buffers { + _, err := c.Write(buffer.Bytes()) + if err != nil { + return err + } + } + return nil + } + paddingLen := 256 + rand.Intn(512) + header := buf.NewSize(4) + common.Must( + binary.Write(header, binary.BigEndian, uint16(bufferLen)), + binary.Write(header, binary.BigEndian, uint16(paddingLen)), + ) + c.writePadding++ + padding := buf.NewSize(paddingLen) + padding.Extend(paddingLen) + buffers = append(append([]*buf.Buffer{header}, buffers...), padding) + } + return c.writer.WriteVectorised(buffers) +} diff --git a/protocol.go b/protocol.go new file mode 100644 index 0000000..106b39c --- /dev/null +++ b/protocol.go @@ -0,0 +1,183 @@ +package mux + +import ( + "encoding/binary" + "io" + "math/rand" + "time" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/rw" +) + +const ( + ProtocolSmux = iota + ProtocolYAMux + ProtocolH2Mux +) + +const ( + Version0 = iota + Version1 +) + +const ( + TCPTimeout = 5 * time.Second +) + +var Destination = M.Socksaddr{ + Fqdn: "sp.mux.sing-box.arpa", + Port: 444, +} + +type Request struct { + Version byte + Protocol byte + Padding bool +} + +func ReadRequest(reader io.Reader) (*Request, error) { + version, err := rw.ReadByte(reader) + if err != nil { + return nil, err + } + if version < Version0 || version > Version1 { + return nil, E.New("unsupported version: ", version) + } + protocol, err := rw.ReadByte(reader) + if err != nil { + return nil, err + } + var paddingEnabled bool + if version == Version1 { + err = binary.Read(reader, binary.BigEndian, &paddingEnabled) + if err != nil { + return nil, err + } + if paddingEnabled { + var paddingLen uint16 + err = binary.Read(reader, binary.BigEndian, &paddingLen) + if err != nil { + return nil, err + } + err = rw.SkipN(reader, int(paddingLen)) + if err != nil { + return nil, err + } + } + } + return &Request{Version: version, Protocol: protocol, Padding: paddingEnabled}, nil +} + +func EncodeRequest(request Request, payload []byte) *buf.Buffer { + var requestLen int + requestLen += 2 + var paddingLen uint16 + if request.Version == Version1 { + requestLen += 1 + if request.Padding { + requestLen += 2 + paddingLen = uint16(256 + rand.Intn(512)) + requestLen += int(paddingLen) + } + } + buffer := buf.NewSize(requestLen + len(payload)) + common.Must( + buffer.WriteByte(request.Version), + buffer.WriteByte(request.Protocol), + ) + if request.Version == Version1 { + common.Must(binary.Write(buffer, binary.BigEndian, request.Padding)) + if request.Padding { + common.Must(binary.Write(buffer, binary.BigEndian, paddingLen)) + buffer.Extend(int(paddingLen)) + } + } + common.Must1(buffer.Write(payload)) + return buffer +} + +const ( + flagUDP = 1 + flagAddr = 2 + statusSuccess = 0 + statusError = 1 +) + +type StreamRequest struct { + Network string + Destination M.Socksaddr + PacketAddr bool +} + +func ReadStreamRequest(reader io.Reader) (*StreamRequest, error) { + var flags uint16 + err := binary.Read(reader, binary.BigEndian, &flags) + if err != nil { + return nil, err + } + destination, err := M.SocksaddrSerializer.ReadAddrPort(reader) + if err != nil { + return nil, err + } + var network string + var udpAddr bool + if flags&flagUDP == 0 { + network = N.NetworkTCP + } else { + network = N.NetworkUDP + udpAddr = flags&flagAddr != 0 + } + return &StreamRequest{network, destination, udpAddr}, nil +} + +func streamRequestLen(request StreamRequest) int { + var rLen int + rLen += 1 // version + rLen += 2 // flags + rLen += M.SocksaddrSerializer.AddrPortLen(request.Destination) + return rLen +} + +func EncodeStreamRequest(request StreamRequest, buffer *buf.Buffer) { + destination := request.Destination + var flags uint16 + if request.Network == N.NetworkUDP { + flags |= flagUDP + } + if request.PacketAddr { + flags |= flagAddr + if !destination.IsValid() { + destination = Destination + } + } + common.Must( + binary.Write(buffer, binary.BigEndian, flags), + M.SocksaddrSerializer.WriteAddrPort(buffer, destination), + ) +} + +type StreamResponse struct { + Status uint8 + Message string +} + +func ReadStreamResponse(reader io.Reader) (*StreamResponse, error) { + var response StreamResponse + status, err := rw.ReadByte(reader) + if err != nil { + return nil, err + } + response.Status = status + if status == statusError { + response.Message, err = rw.ReadVString(reader) + if err != nil { + return nil, err + } + } + return &response, nil +} diff --git a/protocol_conn.go b/protocol_conn.go new file mode 100644 index 0000000..ec3d2a4 --- /dev/null +++ b/protocol_conn.go @@ -0,0 +1,73 @@ +package mux + +import ( + "io" + "net" + + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + N "github.com/sagernet/sing/common/network" +) + +type protocolConn struct { + net.Conn + request Request + requestWritten bool +} + +func newProtocolConn(conn net.Conn, request Request) net.Conn { + writer, isVectorised := bufio.CreateVectorisedWriter(conn) + if isVectorised { + return &vectorisedProtocolConn{ + protocolConn{ + Conn: conn, + request: request, + }, + writer, + } + } else { + return &protocolConn{ + Conn: conn, + request: request, + } + } +} + +func (c *protocolConn) Write(p []byte) (n int, err error) { + if c.requestWritten { + return c.Conn.Write(p) + } + buffer := EncodeRequest(c.request, p) + n, err = c.Conn.Write(buffer.Bytes()) + buffer.Release() + if err == nil { + n-- + } + c.requestWritten = true + return n, err +} + +func (c *protocolConn) ReadFrom(r io.Reader) (n int64, err error) { + if !c.requestWritten { + return bufio.ReadFrom0(c, r) + } + return bufio.Copy(c.Conn, r) +} + +func (c *protocolConn) Upstream() any { + return c.Conn +} + +type vectorisedProtocolConn struct { + protocolConn + writer N.VectorisedWriter +} + +func (c *vectorisedProtocolConn) WriteVectorised(buffers []*buf.Buffer) error { + if c.requestWritten { + return c.writer.WriteVectorised(buffers) + } + c.requestWritten = true + buffer := EncodeRequest(c.request, nil) + return c.writer.WriteVectorised(append([]*buf.Buffer{buffer}, buffers...)) +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..657ff75 --- /dev/null +++ b/server.go @@ -0,0 +1,80 @@ +package mux + +import ( + "context" + "net" + + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/task" +) + +type ServerHandler interface { + N.TCPConnectionHandler + N.UDPConnectionHandler + E.Handler +} + +func HandleConnection(ctx context.Context, handler ServerHandler, logger logger.ContextLogger, conn net.Conn, metadata M.Metadata) error { + request, err := ReadRequest(conn) + if err != nil { + return err + } + if request.Padding { + conn = newPaddingConn(conn) + } + session, err := newServerSession(conn, request.Protocol) + if err != nil { + return err + } + var group task.Group + group.Append0(func(ctx context.Context) error { + var stream net.Conn + for { + stream, err = session.Accept() + if err != nil { + return err + } + go newConnection(ctx, handler, logger, stream, metadata) + } + }) + group.Cleanup(func() { + session.Close() + }) + return group.Run(ctx) +} + +func newConnection(ctx context.Context, handler ServerHandler, logger logger.ContextLogger, stream net.Conn, metadata M.Metadata) { + stream = &wrapStream{stream} + request, err := ReadStreamRequest(stream) + if err != nil { + logger.ErrorContext(ctx, err) + return + } + metadata.Destination = request.Destination + if request.Network == N.NetworkTCP { + logger.InfoContext(ctx, "inbound multiplex connection to ", metadata.Destination) + hErr := handler.NewConnection(ctx, &serverConn{ExtendedConn: bufio.NewExtendedConn(stream)}, metadata) + stream.Close() + if hErr != nil { + handler.NewError(ctx, hErr) + } + } else { + var packetConn N.PacketConn + if !request.PacketAddr { + logger.InfoContext(ctx, "inbound multiplex packet connection to ", metadata.Destination) + packetConn = &serverPacketConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: request.Destination} + } else { + logger.InfoContext(ctx, "inbound multiplex packet connection") + packetConn = &serverPacketAddrConn{ExtendedConn: bufio.NewExtendedConn(stream)} + } + hErr := handler.NewPacketConnection(ctx, packetConn, metadata) + stream.Close() + if hErr != nil { + handler.NewError(ctx, hErr) + } + } +} diff --git a/server_conn.go b/server_conn.go new file mode 100644 index 0000000..52e9db7 --- /dev/null +++ b/server_conn.go @@ -0,0 +1,204 @@ +package mux + +import ( + "encoding/binary" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/rw" +) + +var _ N.HandshakeConn = (*serverConn)(nil) + +type serverConn struct { + N.ExtendedConn + responseWritten bool +} + +func (c *serverConn) HandshakeFailure(err error) error { + errMessage := err.Error() + _buffer := buf.StackNewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage)) + defer common.KeepAlive(_buffer) + buffer := common.Dup(_buffer) + defer buffer.Release() + common.Must( + buffer.WriteByte(statusError), + rw.WriteVString(_buffer, errMessage), + ) + return c.ExtendedConn.WriteBuffer(buffer) +} + +func (c *serverConn) Write(b []byte) (n int, err error) { + if c.responseWritten { + return c.ExtendedConn.Write(b) + } + _buffer := buf.StackNewSize(1 + len(b)) + defer common.KeepAlive(_buffer) + buffer := common.Dup(_buffer) + defer buffer.Release() + common.Must( + buffer.WriteByte(statusSuccess), + common.Error(buffer.Write(b)), + ) + _, err = c.ExtendedConn.Write(buffer.Bytes()) + if err != nil { + return + } + c.responseWritten = true + return len(b), nil +} + +func (c *serverConn) WriteBuffer(buffer *buf.Buffer) error { + if c.responseWritten { + return c.ExtendedConn.WriteBuffer(buffer) + } + buffer.ExtendHeader(1)[0] = statusSuccess + c.responseWritten = true + return c.ExtendedConn.WriteBuffer(buffer) +} + +func (c *serverConn) FrontHeadroom() int { + if !c.responseWritten { + return 1 + } + return 0 +} + +func (c *serverConn) NeedAdditionalReadDeadline() bool { + return true +} + +func (c *serverConn) Upstream() any { + return c.ExtendedConn +} + +var ( + _ N.HandshakeConn = (*serverPacketConn)(nil) + _ N.PacketConn = (*serverPacketConn)(nil) +) + +type serverPacketConn struct { + N.ExtendedConn + destination M.Socksaddr + responseWritten bool +} + +func (c *serverPacketConn) HandshakeFailure(err error) error { + errMessage := err.Error() + _buffer := buf.StackNewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage)) + defer common.KeepAlive(_buffer) + buffer := common.Dup(_buffer) + defer buffer.Release() + common.Must( + buffer.WriteByte(statusError), + rw.WriteVString(_buffer, errMessage), + ) + return c.ExtendedConn.WriteBuffer(buffer) +} + +func (c *serverPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + var length uint16 + err = binary.Read(c.ExtendedConn, binary.BigEndian, &length) + if err != nil { + return + } + _, err = buffer.ReadFullFrom(c.ExtendedConn, int(length)) + if err != nil { + return + } + destination = c.destination + return +} + +func (c *serverPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + pLen := buffer.Len() + common.Must(binary.Write(buf.With(buffer.ExtendHeader(2)), binary.BigEndian, uint16(pLen))) + if !c.responseWritten { + buffer.ExtendHeader(1)[0] = statusSuccess + c.responseWritten = true + } + return c.ExtendedConn.WriteBuffer(buffer) +} + +func (c *serverPacketConn) NeedAdditionalReadDeadline() bool { + return true +} + +func (c *serverPacketConn) Upstream() any { + return c.ExtendedConn +} + +func (c *serverPacketConn) FrontHeadroom() int { + if !c.responseWritten { + return 3 + } + return 2 +} + +var ( + _ N.HandshakeConn = (*serverPacketAddrConn)(nil) + _ N.PacketConn = (*serverPacketAddrConn)(nil) +) + +type serverPacketAddrConn struct { + N.ExtendedConn + responseWritten bool +} + +func (c *serverPacketAddrConn) HandshakeFailure(err error) error { + errMessage := err.Error() + _buffer := buf.StackNewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage)) + defer common.KeepAlive(_buffer) + buffer := common.Dup(_buffer) + defer buffer.Release() + common.Must( + buffer.WriteByte(statusError), + rw.WriteVString(_buffer, errMessage), + ) + return c.ExtendedConn.WriteBuffer(buffer) +} + +func (c *serverPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + destination, err = M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn) + if err != nil { + return + } + var length uint16 + err = binary.Read(c.ExtendedConn, binary.BigEndian, &length) + if err != nil { + return + } + _, err = buffer.ReadFullFrom(c.ExtendedConn, int(length)) + if err != nil { + return + } + return +} + +func (c *serverPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + pLen := buffer.Len() + common.Must(binary.Write(buf.With(buffer.ExtendHeader(2)), binary.BigEndian, uint16(pLen))) + common.Must(M.SocksaddrSerializer.WriteAddrPort(buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination))), destination)) + if !c.responseWritten { + buffer.ExtendHeader(1)[0] = statusSuccess + c.responseWritten = true + } + return c.ExtendedConn.WriteBuffer(buffer) +} + +func (c *serverPacketAddrConn) NeedAdditionalReadDeadline() bool { + return true +} + +func (c *serverPacketAddrConn) Upstream() any { + return c.ExtendedConn +} + +func (c *serverPacketAddrConn) FrontHeadroom() int { + if !c.responseWritten { + return 3 + M.MaxSocksaddrLength + } + return 2 + M.MaxSocksaddrLength +} diff --git a/server_default.go b/server_default.go new file mode 100644 index 0000000..f10247e --- /dev/null +++ b/server_default.go @@ -0,0 +1,36 @@ +package mux + +import ( + "context" + "net" + + "github.com/sagernet/sing/common/bufio" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +func HandleConnectionDefault(ctx context.Context, conn net.Conn) error { + return HandleConnection(ctx, (*defaultServerHandler)(nil), logger.NOP(), conn, M.Metadata{}) +} + +type defaultServerHandler struct{} + +func (h *defaultServerHandler) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { + remoteConn, err := N.SystemDialer.DialContext(ctx, N.NetworkTCP, metadata.Destination) + if err != nil { + return err + } + return bufio.CopyConn(ctx, conn, remoteConn) +} + +func (h *defaultServerHandler) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error { + remoteConn, err := N.SystemDialer.ListenPacket(ctx, metadata.Destination) + if err != nil { + return err + } + return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(remoteConn)) +} + +func (h *defaultServerHandler) NewError(ctx context.Context, err error) { +} diff --git a/session.go b/session.go new file mode 100644 index 0000000..2e674c8 --- /dev/null +++ b/session.go @@ -0,0 +1,106 @@ +package mux + +import ( + "io" + "net" + + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/smux" + + "github.com/hashicorp/yamux" +) + +type abstractSession interface { + Open() (net.Conn, error) + Accept() (net.Conn, error) + NumStreams() int + Close() error + IsClosed() bool + CanTakeNewRequest() bool +} + +func newClientSession(conn net.Conn, protocol byte) (abstractSession, error) { + switch protocol { + case ProtocolH2Mux: + session, err := newH2MuxClient(conn) + if err != nil { + return nil, err + } + return session, nil + case ProtocolSmux: + client, err := smux.Client(conn, smuxConfig()) + if err != nil { + return nil, err + } + return &smuxSession{client}, nil + case ProtocolYAMux: + client, err := yamux.Client(conn, yaMuxConfig()) + if err != nil { + return nil, err + } + return &yamuxSession{client}, nil + default: + return nil, E.New("unexpected protocol ", protocol) + } +} + +func newServerSession(conn net.Conn, protocol byte) (abstractSession, error) { + switch protocol { + case ProtocolH2Mux: + return newH2MuxServer(conn), nil + case ProtocolSmux: + client, err := smux.Server(conn, smuxConfig()) + if err != nil { + return nil, err + } + return &smuxSession{client}, nil + case ProtocolYAMux: + client, err := yamux.Server(conn, yaMuxConfig()) + if err != nil { + return nil, err + } + return &yamuxSession{client}, nil + default: + return nil, E.New("unexpected protocol ", protocol) + } +} + +var _ abstractSession = (*smuxSession)(nil) + +type smuxSession struct { + *smux.Session +} + +func (s *smuxSession) Open() (net.Conn, error) { + return s.OpenStream() +} + +func (s *smuxSession) Accept() (net.Conn, error) { + return s.AcceptStream() +} + +func (s *smuxSession) CanTakeNewRequest() bool { + return true +} + +type yamuxSession struct { + *yamux.Session +} + +func (y *yamuxSession) CanTakeNewRequest() bool { + return true +} + +func smuxConfig() *smux.Config { + config := smux.DefaultConfig() + config.KeepAliveDisabled = true + return config +} + +func yaMuxConfig() *yamux.Config { + config := yamux.DefaultConfig() + config.LogOutput = io.Discard + config.StreamCloseTimeout = TCPTimeout + config.StreamOpenTimeout = TCPTimeout + return config +}