Init commit

This commit is contained in:
世界 2023-04-23 17:05:39 +08:00
commit b8acf3f145
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
22 changed files with 2021 additions and 0 deletions

5
.github/update_dependencies.sh vendored Executable file
View file

@ -0,0 +1,5 @@
#!/usr/bin/env bash
PROJECTS=$(dirname "$0")/../..
go get -x github.com/sagernet/$1@$(git -C $PROJECTS/$1 rev-parse HEAD)
go mod tidy

43
.github/workflows/debug.yml vendored Normal file
View file

@ -0,0 +1,43 @@
name: Debug build
on:
push:
branches:
- main
paths-ignore:
- '**.md'
- '.github/**'
- '!.github/workflows/debug.yml'
pull_request:
branches:
- main
jobs:
build:
name: Debug build
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Get latest go version
id: version
run: |
echo ::set-output name=go_version::$(curl -s https://raw.githubusercontent.com/actions/go-versions/main/versions-manifest.json | grep -oE '"version": "[0-9]{1}.[0-9]{1,}(.[0-9]{1,})?"' | head -1 | cut -d':' -f2 | sed 's/ //g; s/"//g')
- name: Setup Go
uses: actions/setup-go@v4
with:
go-version: ${{ steps.version.outputs.go_version }}
- name: Add cache to Go proxy
run: |
version=`git rev-parse HEAD`
mkdir build
pushd build
go mod init build
go get -v github.com/sagernet/sing-mux@$version
popd
continue-on-error: true
- name: Build
run: |
make test

41
.github/workflows/lint.yml vendored Normal file
View file

@ -0,0 +1,41 @@
name: Lint
on:
push:
branches:
- dev
paths-ignore:
- '**.md'
- '.github/**'
- '!.github/workflows/lint.yml'
pull_request:
branches:
- dev
jobs:
build:
name: Build
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Get latest go version
id: version
run: |
echo ::set-output name=go_version::$(curl -s https://raw.githubusercontent.com/actions/go-versions/main/versions-manifest.json | grep -oE '"version": "[0-9]{1}.[0-9]{1,}(.[0-9]{1,})?"' | head -1 | cut -d':' -f2 | sed 's/ //g; s/"//g')
- name: Setup Go
uses: actions/setup-go@v4
with:
go-version: ${{ steps.version.outputs.go_version }}
- name: Cache go module
uses: actions/cache@v3
with:
path: |
~/go/pkg/mod
key: go-${{ hashFiles('**/go.sum') }}
- name: golangci-lint
uses: golangci/golangci-lint-action@v3
with:
version: latest

2
.gitignore vendored Normal file
View file

@ -0,0 +1,2 @@
/.idea/
/vendor/

17
.golangci.yml Normal file
View file

@ -0,0 +1,17 @@
linters:
disable-all: true
enable:
- gofumpt
- govet
- gci
- staticcheck
linters-settings:
gci:
custom-order: true
sections:
- standard
- prefix(github.com/sagernet/)
- default
staticcheck:
go: '1.20'

14
LICENSE Normal file
View file

@ -0,0 +1,14 @@
Copyright (C) 2022 by nekohasekai <contact-sagernet@sekai.icu>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.

21
Makefile Normal file
View file

@ -0,0 +1,21 @@
fmt:
@gofumpt -l -w .
@gofmt -s -w .
@gci write --custom-order -s "standard,prefix(github.com/sagernet/),default" .
fmt_install:
go install -v mvdan.cc/gofumpt@latest
go install -v github.com/daixiang0/gci@latest
lint:
GOOS=linux golangci-lint run ./...
GOOS=android golangci-lint run ./...
GOOS=windows golangci-lint run ./...
GOOS=darwin golangci-lint run ./...
GOOS=freebsd golangci-lint run ./...
lint_install:
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
test:
go test -v ./...

3
README.md Normal file
View file

@ -0,0 +1,3 @@
# sing-mux
Simple multiplex library.

183
client.go Normal file
View file

@ -0,0 +1,183 @@
package mux
import (
"context"
"net"
"sync"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list"
)
type Client struct {
dialer N.Dialer
protocol byte
maxConnections int
minStreams int
maxStreams int
padding bool
access sync.Mutex
connections list.List[abstractSession]
}
type Options struct {
Dialer N.Dialer
Protocol string
MaxConnections int
MinStreams int
MaxStreams int
Padding bool
}
func NewClient(options Options) (*Client, error) {
client := &Client{
dialer: options.Dialer,
maxConnections: options.MaxConnections,
minStreams: options.MinStreams,
maxStreams: options.MaxStreams,
padding: options.Padding,
}
if client.dialer == nil {
client.dialer = N.SystemDialer
}
if client.maxStreams == 0 && client.maxConnections == 0 {
client.minStreams = 8
}
switch options.Protocol {
case "", "h2mux":
client.protocol = ProtocolH2Mux
case "smux":
client.protocol = ProtocolSmux
case "yamux":
client.protocol = ProtocolYAMux
default:
return nil, E.New("unknown protocol: " + options.Protocol)
}
return client, nil
}
func (c *Client) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
switch N.NetworkName(network) {
case N.NetworkTCP:
stream, err := c.openStream(ctx)
if err != nil {
return nil, err
}
return &clientConn{Conn: stream, destination: destination}, nil
case N.NetworkUDP:
stream, err := c.openStream(ctx)
if err != nil {
return nil, err
}
return bufio.NewUnbindPacketConn(&clientPacketConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: destination}), nil
default:
return nil, E.Extend(N.ErrUnknownNetwork, network)
}
}
func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
stream, err := c.openStream(ctx)
if err != nil {
return nil, err
}
return &clientPacketAddrConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: destination}, nil
}
func (c *Client) openStream(ctx context.Context) (net.Conn, error) {
var (
session abstractSession
stream net.Conn
err error
)
for attempts := 0; attempts < 2; attempts++ {
session, err = c.offer(ctx)
if err != nil {
continue
}
stream, err = session.Open()
if err != nil {
continue
}
break
}
if err != nil {
return nil, err
}
return &wrapStream{stream}, nil
}
func (c *Client) offer(ctx context.Context) (abstractSession, error) {
c.access.Lock()
defer c.access.Unlock()
sessions := make([]abstractSession, 0, c.maxConnections)
for element := c.connections.Front(); element != nil; {
if element.Value.IsClosed() {
nextElement := element.Next()
c.connections.Remove(element)
element = nextElement
continue
}
sessions = append(sessions, element.Value)
element = element.Next()
}
session := common.MinBy(common.Filter(sessions, abstractSession.CanTakeNewRequest), abstractSession.NumStreams)
if session == nil {
return c.offerNew(ctx)
}
numStreams := session.NumStreams()
if numStreams == 0 {
return session, nil
}
if c.maxConnections > 0 {
if len(sessions) >= c.maxConnections || numStreams < c.minStreams {
return session, nil
}
} else {
if c.maxStreams > 0 && numStreams < c.maxStreams {
return session, nil
}
}
return c.offerNew(ctx)
}
func (c *Client) offerNew(ctx context.Context) (abstractSession, error) {
conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, Destination)
if err != nil {
return nil, err
}
var version byte
if c.padding {
version = Version1
} else {
version = Version0
}
conn = newProtocolConn(conn, Request{
Version: version,
Protocol: c.protocol,
Padding: c.padding,
})
if c.padding {
conn = newPaddingConn(conn)
}
session, err := newClientSession(conn, c.protocol)
if err != nil {
conn.Close()
return nil, err
}
c.connections.PushBack(session)
return session, nil
}
func (c *Client) Reset() {
c.access.Lock()
defer c.access.Unlock()
for _, session := range c.connections.Array() {
session.Close()
}
c.connections.Init()
}

380
client_conn.go Normal file
View file

@ -0,0 +1,380 @@
package mux
import (
"encoding/binary"
"io"
"net"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type clientConn struct {
net.Conn
destination M.Socksaddr
requestWritten bool
responseRead bool
}
func (c *clientConn) readResponse() error {
response, err := ReadStreamResponse(c.Conn)
if err != nil {
return err
}
if response.Status == statusError {
return E.New("remote error: ", response.Message)
}
return nil
}
func (c *clientConn) Read(b []byte) (n int, err error) {
if !c.responseRead {
err = c.readResponse()
if err != nil {
return
}
c.responseRead = true
}
return c.Conn.Read(b)
}
func (c *clientConn) Write(b []byte) (n int, err error) {
if c.requestWritten {
return c.Conn.Write(b)
}
request := StreamRequest{
Network: N.NetworkTCP,
Destination: c.destination,
}
_buffer := buf.StackNewSize(streamRequestLen(request) + len(b))
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
EncodeStreamRequest(request, buffer)
buffer.Write(b)
_, err = c.Conn.Write(buffer.Bytes())
if err != nil {
return
}
c.requestWritten = true
return len(b), nil
}
func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
if !c.requestWritten {
return bufio.ReadFrom0(c, r)
}
return bufio.Copy(c.Conn, r)
}
func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) {
if !c.responseRead {
return bufio.WriteTo0(c, w)
}
return bufio.Copy(w, c.Conn)
}
func (c *clientConn) LocalAddr() net.Addr {
return c.Conn.LocalAddr()
}
func (c *clientConn) RemoteAddr() net.Addr {
return c.destination.TCPAddr()
}
func (c *clientConn) ReaderReplaceable() bool {
return c.responseRead
}
func (c *clientConn) WriterReplaceable() bool {
return c.requestWritten
}
func (c *clientConn) NeedAdditionalReadDeadline() bool {
return true
}
func (c *clientConn) Upstream() any {
return c.Conn
}
type clientPacketConn struct {
N.ExtendedConn
destination M.Socksaddr
requestWritten bool
responseRead bool
}
func (c *clientPacketConn) readResponse() error {
response, err := ReadStreamResponse(c.ExtendedConn)
if err != nil {
return err
}
if response.Status == statusError {
return E.New("remote error: ", response.Message)
}
return nil
}
func (c *clientPacketConn) Read(b []byte) (n int, err error) {
if !c.responseRead {
err = c.readResponse()
if err != nil {
return
}
c.responseRead = true
}
var length uint16
err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
if err != nil {
return
}
if cap(b) < int(length) {
return 0, io.ErrShortBuffer
}
return io.ReadFull(c.ExtendedConn, b[:length])
}
func (c *clientPacketConn) writeRequest(payload []byte) (n int, err error) {
request := StreamRequest{
Network: N.NetworkUDP,
Destination: c.destination,
}
rLen := streamRequestLen(request)
if len(payload) > 0 {
rLen += 2 + len(payload)
}
_buffer := buf.StackNewSize(rLen)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
EncodeStreamRequest(request, buffer)
if len(payload) > 0 {
common.Must(
binary.Write(buffer, binary.BigEndian, uint16(len(payload))),
common.Error(buffer.Write(payload)),
)
}
_, err = c.ExtendedConn.Write(buffer.Bytes())
if err != nil {
return
}
c.requestWritten = true
return len(payload), nil
}
func (c *clientPacketConn) Write(b []byte) (n int, err error) {
if !c.requestWritten {
return c.writeRequest(b)
}
err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(b)))
if err != nil {
return
}
return c.ExtendedConn.Write(b)
}
func (c *clientPacketConn) ReadBuffer(buffer *buf.Buffer) (err error) {
if !c.responseRead {
err = c.readResponse()
if err != nil {
return
}
c.responseRead = true
}
var length uint16
err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
if err != nil {
return
}
_, err = buffer.ReadFullFrom(c.ExtendedConn, int(length))
return
}
func (c *clientPacketConn) WriteBuffer(buffer *buf.Buffer) error {
if !c.requestWritten {
defer buffer.Release()
return common.Error(c.writeRequest(buffer.Bytes()))
}
bLen := buffer.Len()
binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(bLen))
return c.ExtendedConn.WriteBuffer(buffer)
}
func (c *clientPacketConn) FrontHeadroom() int {
return 2
}
func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
err = c.ReadBuffer(buffer)
return
}
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
return c.WriteBuffer(buffer)
}
func (c *clientPacketConn) LocalAddr() net.Addr {
return c.ExtendedConn.LocalAddr()
}
func (c *clientPacketConn) RemoteAddr() net.Addr {
return c.destination.UDPAddr()
}
func (c *clientPacketConn) NeedAdditionalReadDeadline() bool {
return true
}
func (c *clientPacketConn) Upstream() any {
return c.ExtendedConn
}
var _ N.NetPacketConn = (*clientPacketAddrConn)(nil)
type clientPacketAddrConn struct {
N.ExtendedConn
destination M.Socksaddr
requestWritten bool
responseRead bool
}
func (c *clientPacketAddrConn) readResponse() error {
response, err := ReadStreamResponse(c.ExtendedConn)
if err != nil {
return err
}
if response.Status == statusError {
return E.New("remote error: ", response.Message)
}
return nil
}
func (c *clientPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
if !c.responseRead {
err = c.readResponse()
if err != nil {
return
}
c.responseRead = true
}
destination, err := M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn)
if err != nil {
return
}
if destination.IsFqdn() {
addr = destination
} else {
addr = destination.UDPAddr()
}
var length uint16
err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
if err != nil {
return
}
if cap(p) < int(length) {
return 0, nil, io.ErrShortBuffer
}
n, err = io.ReadFull(c.ExtendedConn, p[:length])
return
}
func (c *clientPacketAddrConn) writeRequest(payload []byte, destination M.Socksaddr) (n int, err error) {
request := StreamRequest{
Network: N.NetworkUDP,
Destination: c.destination,
PacketAddr: true,
}
rLen := streamRequestLen(request)
if len(payload) > 0 {
rLen += M.SocksaddrSerializer.AddrPortLen(destination) + 2 + len(payload)
}
_buffer := buf.StackNewSize(rLen)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
EncodeStreamRequest(request, buffer)
if len(payload) > 0 {
common.Must(
M.SocksaddrSerializer.WriteAddrPort(buffer, destination),
binary.Write(buffer, binary.BigEndian, uint16(len(payload))),
common.Error(buffer.Write(payload)),
)
}
_, err = c.ExtendedConn.Write(buffer.Bytes())
if err != nil {
return
}
c.requestWritten = true
return len(payload), nil
}
func (c *clientPacketAddrConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if !c.requestWritten {
return c.writeRequest(p, M.SocksaddrFromNet(addr))
}
err = M.SocksaddrSerializer.WriteAddrPort(c.ExtendedConn, M.SocksaddrFromNet(addr))
if err != nil {
return
}
err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(p)))
if err != nil {
return
}
return c.ExtendedConn.Write(p)
}
func (c *clientPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
if !c.responseRead {
err = c.readResponse()
if err != nil {
return
}
c.responseRead = true
}
destination, err = M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn)
if err != nil {
return
}
var length uint16
err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
if err != nil {
return
}
_, err = buffer.ReadFullFrom(c.ExtendedConn, int(length))
return
}
func (c *clientPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
if !c.requestWritten {
defer buffer.Release()
return common.Error(c.writeRequest(buffer.Bytes(), destination))
}
bLen := buffer.Len()
header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination) + 2))
common.Must(
M.SocksaddrSerializer.WriteAddrPort(header, destination),
binary.Write(header, binary.BigEndian, uint16(bLen)),
)
return c.ExtendedConn.WriteBuffer(buffer)
}
func (c *clientPacketAddrConn) LocalAddr() net.Addr {
return c.ExtendedConn.LocalAddr()
}
func (c *clientPacketAddrConn) FrontHeadroom() int {
return 2 + M.MaxSocksaddrLength
}
func (c *clientPacketAddrConn) NeedAdditionalReadDeadline() bool {
return true
}
func (c *clientPacketAddrConn) Upstream() any {
return c.ExtendedConn
}

37
error.go Normal file
View file

@ -0,0 +1,37 @@
package mux
import (
"io"
"net"
"github.com/hashicorp/yamux"
)
type wrapStream struct {
net.Conn
}
func (w *wrapStream) Read(p []byte) (n int, err error) {
n, err = w.Conn.Read(p)
err = wrapError(err)
return
}
func (w *wrapStream) Write(p []byte) (n int, err error) {
n, err = w.Conn.Write(p)
err = wrapError(err)
return
}
func (w *wrapStream) Upstream() any {
return w.Conn
}
func wrapError(err error) error {
switch err {
case yamux.ErrStreamClosed:
return io.EOF
default:
return err
}
}

15
go.mod Normal file
View file

@ -0,0 +1,15 @@
module github.com/sagernet/sing-mux
go 1.18
require (
github.com/hashicorp/yamux v0.1.1
github.com/sagernet/sing v0.2.5-0.20230423085534-0902e6216207
github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37
golang.org/x/net v0.9.0
)
require (
golang.org/x/sys v0.7.0 // indirect
golang.org/x/text v0.9.0 // indirect
)

14
go.sum Normal file
View file

@ -0,0 +1,14 @@
github.com/hashicorp/yamux v0.1.1 h1:yrQxtgseBDrq9Y652vSRDvsKCJKOUD+GzTS4Y0Y8pvE=
github.com/hashicorp/yamux v0.1.1/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ=
github.com/sagernet/sing v0.1.8/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk=
github.com/sagernet/sing v0.2.5-0.20230423085534-0902e6216207 h1:+dDVjW20IT+e8maKryaDeRY2+RFmTFdrQeIzqE2WOss=
github.com/sagernet/sing v0.2.5-0.20230423085534-0902e6216207/go.mod h1:Ta8nHnDLAwqySzKhGoKk4ZIB+vJ3GTKj7UPrWYvM+4w=
github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37 h1:HuE6xSwco/Xed8ajZ+coeYLmioq0Qp1/Z2zczFaV8as=
github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37/go.mod h1:3skNSftZDJWTGVtVaM2jfbce8qHnmH/AGDRe62iNOg0=
golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM=
golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU=
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=

242
h2mux.go Normal file
View file

@ -0,0 +1,242 @@
package mux
import (
"context"
"crypto/tls"
"io"
"net"
"net/http"
"net/url"
"os"
"time"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
"golang.org/x/net/http2"
)
const idleTimeout = 30 * time.Second
var _ abstractSession = (*h2MuxServerSession)(nil)
type h2MuxServerSession struct {
server http2.Server
active atomic.Int32
conn net.Conn
inbound chan net.Conn
done chan struct{}
}
func newH2MuxServer(conn net.Conn) *h2MuxServerSession {
session := &h2MuxServerSession{
conn: conn,
inbound: make(chan net.Conn),
done: make(chan struct{}),
server: http2.Server{
IdleTimeout: idleTimeout,
},
}
go func() {
session.server.ServeConn(conn, &http2.ServeConnOpts{
Handler: session,
})
_ = session.Close()
}()
return session
}
func (s *h2MuxServerSession) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
s.active.Add(1)
defer s.active.Add(-1)
writer.WriteHeader(http.StatusOK)
conn := newHTTP2Wrapper(newHTTPConn(request.Body, writer), writer.(http.Flusher))
s.inbound <- conn
select {
case <-conn.done:
case <-s.done:
_ = conn.Close()
}
}
func (s *h2MuxServerSession) Open() (net.Conn, error) {
return nil, os.ErrInvalid
}
func (s *h2MuxServerSession) Accept() (net.Conn, error) {
select {
case conn := <-s.inbound:
return conn, nil
case <-s.done:
return nil, os.ErrClosed
}
}
func (s *h2MuxServerSession) NumStreams() int {
return int(s.active.Load())
}
func (s *h2MuxServerSession) Close() error {
select {
case <-s.done:
default:
close(s.done)
}
return s.conn.Close()
}
func (s *h2MuxServerSession) IsClosed() bool {
select {
case <-s.done:
return true
default:
return false
}
}
func (s *h2MuxServerSession) CanTakeNewRequest() bool {
return false
}
type h2MuxConnWrapper struct {
N.ExtendedConn
flusher http.Flusher
done chan struct{}
}
func newHTTP2Wrapper(conn net.Conn, flusher http.Flusher) *h2MuxConnWrapper {
return &h2MuxConnWrapper{
ExtendedConn: bufio.NewExtendedConn(conn),
flusher: flusher,
done: make(chan struct{}),
}
}
func (w *h2MuxConnWrapper) Write(p []byte) (n int, err error) {
select {
case <-w.done:
return 0, net.ErrClosed
default:
}
n, err = w.ExtendedConn.Write(p)
if err == nil {
w.flusher.Flush()
}
return
}
func (w *h2MuxConnWrapper) WriteBuffer(buffer *buf.Buffer) error {
select {
case <-w.done:
return net.ErrClosed
default:
}
err := w.ExtendedConn.WriteBuffer(buffer)
if err == nil {
w.flusher.Flush()
}
return err
}
func (w *h2MuxConnWrapper) Close() error {
select {
case <-w.done:
default:
close(w.done)
}
return w.ExtendedConn.Close()
}
func (w *h2MuxConnWrapper) Upstream() any {
return w.ExtendedConn
}
var _ abstractSession = (*h2MuxClientSession)(nil)
type h2MuxClientSession struct {
transport *http2.Transport
clientConn *http2.ClientConn
done chan struct{}
}
func newH2MuxClient(conn net.Conn) (*h2MuxClientSession, error) {
session := &h2MuxClientSession{
transport: &http2.Transport{
DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
return conn, nil
},
ReadIdleTimeout: idleTimeout,
},
done: make(chan struct{}),
}
session.transport.ConnPool = session
clientConn, err := session.transport.NewClientConn(conn)
if err != nil {
return nil, err
}
session.clientConn = clientConn
return session, nil
}
func (s *h2MuxClientSession) GetClientConn(req *http.Request, addr string) (*http2.ClientConn, error) {
return s.clientConn, nil
}
func (s *h2MuxClientSession) MarkDead(conn *http2.ClientConn) {
s.Close()
}
func (s *h2MuxClientSession) Open() (net.Conn, error) {
pipeInReader, pipeInWriter := io.Pipe()
request := &http.Request{
Method: http.MethodConnect,
Body: pipeInReader,
URL: &url.URL{Scheme: "https", Host: "localhost"},
}
conn := newLateHTTPConn(pipeInWriter)
go func() {
response, err := s.transport.RoundTrip(request)
if err != nil {
conn.setup(nil, err)
} else if response.StatusCode != 200 {
response.Body.Close()
conn.setup(nil, E.New("unexpected status: ", response.StatusCode, " ", response.Status))
} else {
conn.setup(response.Body, nil)
}
}()
return conn, nil
}
func (s *h2MuxClientSession) Accept() (net.Conn, error) {
return nil, os.ErrInvalid
}
func (s *h2MuxClientSession) NumStreams() int {
return s.clientConn.State().StreamsActive
}
func (s *h2MuxClientSession) Close() error {
select {
case <-s.done:
default:
close(s.done)
}
return s.clientConn.Close()
}
func (s *h2MuxClientSession) IsClosed() bool {
select {
case <-s.done:
return true
default:
}
return s.clientConn.State().Closed
}
func (s *h2MuxClientSession) CanTakeNewRequest() bool {
return s.clientConn.CanTakeNewRequest()
}

82
h2mux_conn.go Normal file
View file

@ -0,0 +1,82 @@
package mux
import (
"io"
"net"
"os"
"time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/baderror"
)
type httpConn struct {
reader io.Reader
writer io.Writer
create chan struct{}
err error
}
func newHTTPConn(reader io.Reader, writer io.Writer) *httpConn {
return &httpConn{
reader: reader,
writer: writer,
}
}
func newLateHTTPConn(writer io.Writer) *httpConn {
return &httpConn{
create: make(chan struct{}),
writer: writer,
}
}
func (c *httpConn) setup(reader io.Reader, err error) {
c.reader = reader
c.err = err
close(c.create)
}
func (c *httpConn) Read(b []byte) (n int, err error) {
if c.reader == nil {
<-c.create
if c.err != nil {
return 0, c.err
}
}
n, err = c.reader.Read(b)
return n, baderror.WrapH2(err)
}
func (c *httpConn) Write(b []byte) (n int, err error) {
n, err = c.writer.Write(b)
return n, baderror.WrapH2(err)
}
func (c *httpConn) Close() error {
return common.Close(c.reader, c.writer)
}
func (c *httpConn) LocalAddr() net.Addr {
return nil
}
func (c *httpConn) RemoteAddr() net.Addr {
return nil
}
func (c *httpConn) SetDeadline(t time.Time) error {
return os.ErrInvalid
}
func (c *httpConn) SetReadDeadline(t time.Time) error {
return os.ErrInvalid
}
func (c *httpConn) SetWriteDeadline(t time.Time) error {
return os.ErrInvalid
}
func (c *httpConn) NeedAdditionalReadDeadline() bool {
return true
}

240
padding.go Normal file
View file

@ -0,0 +1,240 @@
package mux
import (
"encoding/binary"
"io"
"math/rand"
"net"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
)
const kFirstPaddings = 16
type paddingConn struct {
N.ExtendedConn
writer N.VectorisedWriter
readPadding int
writePadding int
readRemaining int
paddingRemaining int
}
func newPaddingConn(conn net.Conn) net.Conn {
writer, isVectorised := bufio.CreateVectorisedWriter(conn)
if isVectorised {
return &vectorisedPaddingConn{
paddingConn{
ExtendedConn: bufio.NewExtendedConn(conn),
writer: bufio.NewVectorisedWriter(conn),
},
writer,
}
} else {
return &paddingConn{
ExtendedConn: bufio.NewExtendedConn(conn),
writer: bufio.NewVectorisedWriter(conn),
}
}
}
func (c *paddingConn) Read(p []byte) (n int, err error) {
if c.readRemaining > 0 {
if len(p) > c.readRemaining {
p = p[:c.readRemaining]
}
n, err = c.ExtendedConn.Read(p)
if err != nil {
return
}
c.readRemaining -= n
return
}
if c.paddingRemaining > 0 {
err = rw.SkipN(c.ExtendedConn, c.paddingRemaining)
if err != nil {
return
}
c.paddingRemaining = 0
}
if c.readPadding < kFirstPaddings {
var paddingHdr []byte
if len(p) >= 4 {
paddingHdr = p[:4]
} else {
_paddingHdr := make([]byte, 4)
defer common.KeepAlive(_paddingHdr)
paddingHdr = common.Dup(_paddingHdr)
}
_, err = io.ReadFull(c.ExtendedConn, paddingHdr)
if err != nil {
return
}
originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2]))
paddingLen := int(binary.BigEndian.Uint16(paddingHdr[2:]))
if len(p) > originalDataSize {
p = p[:originalDataSize]
}
n, err = c.ExtendedConn.Read(p)
if err != nil {
return
}
c.readPadding++
c.readRemaining = originalDataSize - n
c.paddingRemaining = paddingLen
return
}
return c.ExtendedConn.Read(p)
}
func (c *paddingConn) Write(p []byte) (n int, err error) {
for pLen := len(p); pLen > 0; {
var data []byte
if pLen > 65535 {
data = p[:65535]
p = p[65535:]
pLen -= 65535
} else {
data = p
pLen = 0
}
var writeN int
writeN, err = c.write(data)
n += writeN
if err != nil {
break
}
}
return n, err
}
func (c *paddingConn) write(p []byte) (n int, err error) {
if c.writePadding < kFirstPaddings {
paddingLen := 256 + rand.Intn(512)
_buffer := buf.StackNewSize(4 + len(p) + paddingLen)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
header := buffer.Extend(4)
binary.BigEndian.PutUint16(header[:2], uint16(len(p)))
binary.BigEndian.PutUint16(header[2:], uint16(paddingLen))
common.Must1(buffer.Write(p))
buffer.Extend(paddingLen)
_, err = c.ExtendedConn.Write(buffer.Bytes())
if err == nil {
n = len(p)
}
c.writePadding++
return
}
return c.ExtendedConn.Write(p)
}
func (c *paddingConn) ReadBuffer(buffer *buf.Buffer) error {
p := buffer.FreeBytes()
if c.readRemaining > 0 {
if len(p) > c.readRemaining {
p = p[:c.readRemaining]
}
n, err := c.ExtendedConn.Read(p)
if err != nil {
return err
}
c.readRemaining -= n
buffer.Truncate(n)
return nil
}
if c.paddingRemaining > 0 {
err := rw.SkipN(c.ExtendedConn, c.paddingRemaining)
if err != nil {
return err
}
c.paddingRemaining = 0
}
if c.readPadding < kFirstPaddings {
var paddingHdr []byte
if len(p) >= 4 {
paddingHdr = p[:4]
} else {
_paddingHdr := make([]byte, 4)
defer common.KeepAlive(_paddingHdr)
paddingHdr = common.Dup(_paddingHdr)
}
_, err := io.ReadFull(c.ExtendedConn, paddingHdr)
if err != nil {
return err
}
originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2]))
paddingLen := int(binary.BigEndian.Uint16(paddingHdr[2:]))
if len(p) > originalDataSize {
p = p[:originalDataSize]
}
n, err := c.ExtendedConn.Read(p)
if err != nil {
return err
}
c.readPadding++
c.readRemaining = originalDataSize - n
c.paddingRemaining = paddingLen
buffer.Truncate(n)
return nil
}
return c.ExtendedConn.ReadBuffer(buffer)
}
func (c *paddingConn) WriteBuffer(buffer *buf.Buffer) error {
if c.writePadding < kFirstPaddings {
bufferLen := buffer.Len()
if bufferLen > 65535 {
return common.Error(c.Write(buffer.Bytes()))
}
paddingLen := 256 + rand.Intn(512)
header := buffer.ExtendHeader(4)
binary.BigEndian.PutUint16(header[:2], uint16(bufferLen))
binary.BigEndian.PutUint16(header[2:], uint16(paddingLen))
buffer.Extend(paddingLen)
c.writePadding++
}
return c.ExtendedConn.WriteBuffer(buffer)
}
func (c *paddingConn) FrontHeadroom() int {
return 4 + 256 + 1024
}
type vectorisedPaddingConn struct {
paddingConn
writer N.VectorisedWriter
}
func (c *vectorisedPaddingConn) WriteVectorised(buffers []*buf.Buffer) error {
if c.writePadding < kFirstPaddings {
bufferLen := buf.LenMulti(buffers)
if bufferLen > 65535 {
defer buf.ReleaseMulti(buffers)
for _, buffer := range buffers {
_, err := c.Write(buffer.Bytes())
if err != nil {
return err
}
}
return nil
}
paddingLen := 256 + rand.Intn(512)
header := buf.NewSize(4)
common.Must(
binary.Write(header, binary.BigEndian, uint16(bufferLen)),
binary.Write(header, binary.BigEndian, uint16(paddingLen)),
)
c.writePadding++
padding := buf.NewSize(paddingLen)
padding.Extend(paddingLen)
buffers = append(append([]*buf.Buffer{header}, buffers...), padding)
}
return c.writer.WriteVectorised(buffers)
}

183
protocol.go Normal file
View file

@ -0,0 +1,183 @@
package mux
import (
"encoding/binary"
"io"
"math/rand"
"time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
)
const (
ProtocolSmux = iota
ProtocolYAMux
ProtocolH2Mux
)
const (
Version0 = iota
Version1
)
const (
TCPTimeout = 5 * time.Second
)
var Destination = M.Socksaddr{
Fqdn: "sp.mux.sing-box.arpa",
Port: 444,
}
type Request struct {
Version byte
Protocol byte
Padding bool
}
func ReadRequest(reader io.Reader) (*Request, error) {
version, err := rw.ReadByte(reader)
if err != nil {
return nil, err
}
if version < Version0 || version > Version1 {
return nil, E.New("unsupported version: ", version)
}
protocol, err := rw.ReadByte(reader)
if err != nil {
return nil, err
}
var paddingEnabled bool
if version == Version1 {
err = binary.Read(reader, binary.BigEndian, &paddingEnabled)
if err != nil {
return nil, err
}
if paddingEnabled {
var paddingLen uint16
err = binary.Read(reader, binary.BigEndian, &paddingLen)
if err != nil {
return nil, err
}
err = rw.SkipN(reader, int(paddingLen))
if err != nil {
return nil, err
}
}
}
return &Request{Version: version, Protocol: protocol, Padding: paddingEnabled}, nil
}
func EncodeRequest(request Request, payload []byte) *buf.Buffer {
var requestLen int
requestLen += 2
var paddingLen uint16
if request.Version == Version1 {
requestLen += 1
if request.Padding {
requestLen += 2
paddingLen = uint16(256 + rand.Intn(512))
requestLen += int(paddingLen)
}
}
buffer := buf.NewSize(requestLen + len(payload))
common.Must(
buffer.WriteByte(request.Version),
buffer.WriteByte(request.Protocol),
)
if request.Version == Version1 {
common.Must(binary.Write(buffer, binary.BigEndian, request.Padding))
if request.Padding {
common.Must(binary.Write(buffer, binary.BigEndian, paddingLen))
buffer.Extend(int(paddingLen))
}
}
common.Must1(buffer.Write(payload))
return buffer
}
const (
flagUDP = 1
flagAddr = 2
statusSuccess = 0
statusError = 1
)
type StreamRequest struct {
Network string
Destination M.Socksaddr
PacketAddr bool
}
func ReadStreamRequest(reader io.Reader) (*StreamRequest, error) {
var flags uint16
err := binary.Read(reader, binary.BigEndian, &flags)
if err != nil {
return nil, err
}
destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
if err != nil {
return nil, err
}
var network string
var udpAddr bool
if flags&flagUDP == 0 {
network = N.NetworkTCP
} else {
network = N.NetworkUDP
udpAddr = flags&flagAddr != 0
}
return &StreamRequest{network, destination, udpAddr}, nil
}
func streamRequestLen(request StreamRequest) int {
var rLen int
rLen += 1 // version
rLen += 2 // flags
rLen += M.SocksaddrSerializer.AddrPortLen(request.Destination)
return rLen
}
func EncodeStreamRequest(request StreamRequest, buffer *buf.Buffer) {
destination := request.Destination
var flags uint16
if request.Network == N.NetworkUDP {
flags |= flagUDP
}
if request.PacketAddr {
flags |= flagAddr
if !destination.IsValid() {
destination = Destination
}
}
common.Must(
binary.Write(buffer, binary.BigEndian, flags),
M.SocksaddrSerializer.WriteAddrPort(buffer, destination),
)
}
type StreamResponse struct {
Status uint8
Message string
}
func ReadStreamResponse(reader io.Reader) (*StreamResponse, error) {
var response StreamResponse
status, err := rw.ReadByte(reader)
if err != nil {
return nil, err
}
response.Status = status
if status == statusError {
response.Message, err = rw.ReadVString(reader)
if err != nil {
return nil, err
}
}
return &response, nil
}

73
protocol_conn.go Normal file
View file

@ -0,0 +1,73 @@
package mux
import (
"io"
"net"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
N "github.com/sagernet/sing/common/network"
)
type protocolConn struct {
net.Conn
request Request
requestWritten bool
}
func newProtocolConn(conn net.Conn, request Request) net.Conn {
writer, isVectorised := bufio.CreateVectorisedWriter(conn)
if isVectorised {
return &vectorisedProtocolConn{
protocolConn{
Conn: conn,
request: request,
},
writer,
}
} else {
return &protocolConn{
Conn: conn,
request: request,
}
}
}
func (c *protocolConn) Write(p []byte) (n int, err error) {
if c.requestWritten {
return c.Conn.Write(p)
}
buffer := EncodeRequest(c.request, p)
n, err = c.Conn.Write(buffer.Bytes())
buffer.Release()
if err == nil {
n--
}
c.requestWritten = true
return n, err
}
func (c *protocolConn) ReadFrom(r io.Reader) (n int64, err error) {
if !c.requestWritten {
return bufio.ReadFrom0(c, r)
}
return bufio.Copy(c.Conn, r)
}
func (c *protocolConn) Upstream() any {
return c.Conn
}
type vectorisedProtocolConn struct {
protocolConn
writer N.VectorisedWriter
}
func (c *vectorisedProtocolConn) WriteVectorised(buffers []*buf.Buffer) error {
if c.requestWritten {
return c.writer.WriteVectorised(buffers)
}
c.requestWritten = true
buffer := EncodeRequest(c.request, nil)
return c.writer.WriteVectorised(append([]*buf.Buffer{buffer}, buffers...))
}

80
server.go Normal file
View file

@ -0,0 +1,80 @@
package mux
import (
"context"
"net"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/task"
)
type ServerHandler interface {
N.TCPConnectionHandler
N.UDPConnectionHandler
E.Handler
}
func HandleConnection(ctx context.Context, handler ServerHandler, logger logger.ContextLogger, conn net.Conn, metadata M.Metadata) error {
request, err := ReadRequest(conn)
if err != nil {
return err
}
if request.Padding {
conn = newPaddingConn(conn)
}
session, err := newServerSession(conn, request.Protocol)
if err != nil {
return err
}
var group task.Group
group.Append0(func(ctx context.Context) error {
var stream net.Conn
for {
stream, err = session.Accept()
if err != nil {
return err
}
go newConnection(ctx, handler, logger, stream, metadata)
}
})
group.Cleanup(func() {
session.Close()
})
return group.Run(ctx)
}
func newConnection(ctx context.Context, handler ServerHandler, logger logger.ContextLogger, stream net.Conn, metadata M.Metadata) {
stream = &wrapStream{stream}
request, err := ReadStreamRequest(stream)
if err != nil {
logger.ErrorContext(ctx, err)
return
}
metadata.Destination = request.Destination
if request.Network == N.NetworkTCP {
logger.InfoContext(ctx, "inbound multiplex connection to ", metadata.Destination)
hErr := handler.NewConnection(ctx, &serverConn{ExtendedConn: bufio.NewExtendedConn(stream)}, metadata)
stream.Close()
if hErr != nil {
handler.NewError(ctx, hErr)
}
} else {
var packetConn N.PacketConn
if !request.PacketAddr {
logger.InfoContext(ctx, "inbound multiplex packet connection to ", metadata.Destination)
packetConn = &serverPacketConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: request.Destination}
} else {
logger.InfoContext(ctx, "inbound multiplex packet connection")
packetConn = &serverPacketAddrConn{ExtendedConn: bufio.NewExtendedConn(stream)}
}
hErr := handler.NewPacketConnection(ctx, packetConn, metadata)
stream.Close()
if hErr != nil {
handler.NewError(ctx, hErr)
}
}
}

204
server_conn.go Normal file
View file

@ -0,0 +1,204 @@
package mux
import (
"encoding/binary"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
)
var _ N.HandshakeConn = (*serverConn)(nil)
type serverConn struct {
N.ExtendedConn
responseWritten bool
}
func (c *serverConn) HandshakeFailure(err error) error {
errMessage := err.Error()
_buffer := buf.StackNewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage))
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
common.Must(
buffer.WriteByte(statusError),
rw.WriteVString(_buffer, errMessage),
)
return c.ExtendedConn.WriteBuffer(buffer)
}
func (c *serverConn) Write(b []byte) (n int, err error) {
if c.responseWritten {
return c.ExtendedConn.Write(b)
}
_buffer := buf.StackNewSize(1 + len(b))
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
common.Must(
buffer.WriteByte(statusSuccess),
common.Error(buffer.Write(b)),
)
_, err = c.ExtendedConn.Write(buffer.Bytes())
if err != nil {
return
}
c.responseWritten = true
return len(b), nil
}
func (c *serverConn) WriteBuffer(buffer *buf.Buffer) error {
if c.responseWritten {
return c.ExtendedConn.WriteBuffer(buffer)
}
buffer.ExtendHeader(1)[0] = statusSuccess
c.responseWritten = true
return c.ExtendedConn.WriteBuffer(buffer)
}
func (c *serverConn) FrontHeadroom() int {
if !c.responseWritten {
return 1
}
return 0
}
func (c *serverConn) NeedAdditionalReadDeadline() bool {
return true
}
func (c *serverConn) Upstream() any {
return c.ExtendedConn
}
var (
_ N.HandshakeConn = (*serverPacketConn)(nil)
_ N.PacketConn = (*serverPacketConn)(nil)
)
type serverPacketConn struct {
N.ExtendedConn
destination M.Socksaddr
responseWritten bool
}
func (c *serverPacketConn) HandshakeFailure(err error) error {
errMessage := err.Error()
_buffer := buf.StackNewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage))
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
common.Must(
buffer.WriteByte(statusError),
rw.WriteVString(_buffer, errMessage),
)
return c.ExtendedConn.WriteBuffer(buffer)
}
func (c *serverPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
var length uint16
err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
if err != nil {
return
}
_, err = buffer.ReadFullFrom(c.ExtendedConn, int(length))
if err != nil {
return
}
destination = c.destination
return
}
func (c *serverPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
pLen := buffer.Len()
common.Must(binary.Write(buf.With(buffer.ExtendHeader(2)), binary.BigEndian, uint16(pLen)))
if !c.responseWritten {
buffer.ExtendHeader(1)[0] = statusSuccess
c.responseWritten = true
}
return c.ExtendedConn.WriteBuffer(buffer)
}
func (c *serverPacketConn) NeedAdditionalReadDeadline() bool {
return true
}
func (c *serverPacketConn) Upstream() any {
return c.ExtendedConn
}
func (c *serverPacketConn) FrontHeadroom() int {
if !c.responseWritten {
return 3
}
return 2
}
var (
_ N.HandshakeConn = (*serverPacketAddrConn)(nil)
_ N.PacketConn = (*serverPacketAddrConn)(nil)
)
type serverPacketAddrConn struct {
N.ExtendedConn
responseWritten bool
}
func (c *serverPacketAddrConn) HandshakeFailure(err error) error {
errMessage := err.Error()
_buffer := buf.StackNewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage))
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
common.Must(
buffer.WriteByte(statusError),
rw.WriteVString(_buffer, errMessage),
)
return c.ExtendedConn.WriteBuffer(buffer)
}
func (c *serverPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
destination, err = M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn)
if err != nil {
return
}
var length uint16
err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
if err != nil {
return
}
_, err = buffer.ReadFullFrom(c.ExtendedConn, int(length))
if err != nil {
return
}
return
}
func (c *serverPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
pLen := buffer.Len()
common.Must(binary.Write(buf.With(buffer.ExtendHeader(2)), binary.BigEndian, uint16(pLen)))
common.Must(M.SocksaddrSerializer.WriteAddrPort(buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination))), destination))
if !c.responseWritten {
buffer.ExtendHeader(1)[0] = statusSuccess
c.responseWritten = true
}
return c.ExtendedConn.WriteBuffer(buffer)
}
func (c *serverPacketAddrConn) NeedAdditionalReadDeadline() bool {
return true
}
func (c *serverPacketAddrConn) Upstream() any {
return c.ExtendedConn
}
func (c *serverPacketAddrConn) FrontHeadroom() int {
if !c.responseWritten {
return 3 + M.MaxSocksaddrLength
}
return 2 + M.MaxSocksaddrLength
}

36
server_default.go Normal file
View file

@ -0,0 +1,36 @@
package mux
import (
"context"
"net"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func HandleConnectionDefault(ctx context.Context, conn net.Conn) error {
return HandleConnection(ctx, (*defaultServerHandler)(nil), logger.NOP(), conn, M.Metadata{})
}
type defaultServerHandler struct{}
func (h *defaultServerHandler) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
remoteConn, err := N.SystemDialer.DialContext(ctx, N.NetworkTCP, metadata.Destination)
if err != nil {
return err
}
return bufio.CopyConn(ctx, conn, remoteConn)
}
func (h *defaultServerHandler) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error {
remoteConn, err := N.SystemDialer.ListenPacket(ctx, metadata.Destination)
if err != nil {
return err
}
return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(remoteConn))
}
func (h *defaultServerHandler) NewError(ctx context.Context, err error) {
}

106
session.go Normal file
View file

@ -0,0 +1,106 @@
package mux
import (
"io"
"net"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/smux"
"github.com/hashicorp/yamux"
)
type abstractSession interface {
Open() (net.Conn, error)
Accept() (net.Conn, error)
NumStreams() int
Close() error
IsClosed() bool
CanTakeNewRequest() bool
}
func newClientSession(conn net.Conn, protocol byte) (abstractSession, error) {
switch protocol {
case ProtocolH2Mux:
session, err := newH2MuxClient(conn)
if err != nil {
return nil, err
}
return session, nil
case ProtocolSmux:
client, err := smux.Client(conn, smuxConfig())
if err != nil {
return nil, err
}
return &smuxSession{client}, nil
case ProtocolYAMux:
client, err := yamux.Client(conn, yaMuxConfig())
if err != nil {
return nil, err
}
return &yamuxSession{client}, nil
default:
return nil, E.New("unexpected protocol ", protocol)
}
}
func newServerSession(conn net.Conn, protocol byte) (abstractSession, error) {
switch protocol {
case ProtocolH2Mux:
return newH2MuxServer(conn), nil
case ProtocolSmux:
client, err := smux.Server(conn, smuxConfig())
if err != nil {
return nil, err
}
return &smuxSession{client}, nil
case ProtocolYAMux:
client, err := yamux.Server(conn, yaMuxConfig())
if err != nil {
return nil, err
}
return &yamuxSession{client}, nil
default:
return nil, E.New("unexpected protocol ", protocol)
}
}
var _ abstractSession = (*smuxSession)(nil)
type smuxSession struct {
*smux.Session
}
func (s *smuxSession) Open() (net.Conn, error) {
return s.OpenStream()
}
func (s *smuxSession) Accept() (net.Conn, error) {
return s.AcceptStream()
}
func (s *smuxSession) CanTakeNewRequest() bool {
return true
}
type yamuxSession struct {
*yamux.Session
}
func (y *yamuxSession) CanTakeNewRequest() bool {
return true
}
func smuxConfig() *smux.Config {
config := smux.DefaultConfig()
config.KeepAliveDisabled = true
return config
}
func yaMuxConfig() *yamux.Config {
config := yamux.DefaultConfig()
config.LogOutput = io.Discard
config.StreamCloseTimeout = TCPTimeout
config.StreamOpenTimeout = TCPTimeout
return config
}