mirror of
https://github.com/SagerNet/sing-mux.git
synced 2025-04-01 19:17:36 +03:00
Init commit
This commit is contained in:
commit
b8acf3f145
22 changed files with 2021 additions and 0 deletions
5
.github/update_dependencies.sh
vendored
Executable file
5
.github/update_dependencies.sh
vendored
Executable file
|
@ -0,0 +1,5 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
PROJECTS=$(dirname "$0")/../..
|
||||
go get -x github.com/sagernet/$1@$(git -C $PROJECTS/$1 rev-parse HEAD)
|
||||
go mod tidy
|
43
.github/workflows/debug.yml
vendored
Normal file
43
.github/workflows/debug.yml
vendored
Normal file
|
@ -0,0 +1,43 @@
|
|||
name: Debug build
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths-ignore:
|
||||
- '**.md'
|
||||
- '.github/**'
|
||||
- '!.github/workflows/debug.yml'
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: Debug build
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Get latest go version
|
||||
id: version
|
||||
run: |
|
||||
echo ::set-output name=go_version::$(curl -s https://raw.githubusercontent.com/actions/go-versions/main/versions-manifest.json | grep -oE '"version": "[0-9]{1}.[0-9]{1,}(.[0-9]{1,})?"' | head -1 | cut -d':' -f2 | sed 's/ //g; s/"//g')
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: ${{ steps.version.outputs.go_version }}
|
||||
- name: Add cache to Go proxy
|
||||
run: |
|
||||
version=`git rev-parse HEAD`
|
||||
mkdir build
|
||||
pushd build
|
||||
go mod init build
|
||||
go get -v github.com/sagernet/sing-mux@$version
|
||||
popd
|
||||
continue-on-error: true
|
||||
- name: Build
|
||||
run: |
|
||||
make test
|
41
.github/workflows/lint.yml
vendored
Normal file
41
.github/workflows/lint.yml
vendored
Normal file
|
@ -0,0 +1,41 @@
|
|||
name: Lint
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- dev
|
||||
paths-ignore:
|
||||
- '**.md'
|
||||
- '.github/**'
|
||||
- '!.github/workflows/lint.yml'
|
||||
pull_request:
|
||||
branches:
|
||||
- dev
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: Build
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Get latest go version
|
||||
id: version
|
||||
run: |
|
||||
echo ::set-output name=go_version::$(curl -s https://raw.githubusercontent.com/actions/go-versions/main/versions-manifest.json | grep -oE '"version": "[0-9]{1}.[0-9]{1,}(.[0-9]{1,})?"' | head -1 | cut -d':' -f2 | sed 's/ //g; s/"//g')
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: ${{ steps.version.outputs.go_version }}
|
||||
- name: Cache go module
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
key: go-${{ hashFiles('**/go.sum') }}
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v3
|
||||
with:
|
||||
version: latest
|
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
|
@ -0,0 +1,2 @@
|
|||
/.idea/
|
||||
/vendor/
|
17
.golangci.yml
Normal file
17
.golangci.yml
Normal file
|
@ -0,0 +1,17 @@
|
|||
linters:
|
||||
disable-all: true
|
||||
enable:
|
||||
- gofumpt
|
||||
- govet
|
||||
- gci
|
||||
- staticcheck
|
||||
|
||||
linters-settings:
|
||||
gci:
|
||||
custom-order: true
|
||||
sections:
|
||||
- standard
|
||||
- prefix(github.com/sagernet/)
|
||||
- default
|
||||
staticcheck:
|
||||
go: '1.20'
|
14
LICENSE
Normal file
14
LICENSE
Normal file
|
@ -0,0 +1,14 @@
|
|||
Copyright (C) 2022 by nekohasekai <contact-sagernet@sekai.icu>
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
21
Makefile
Normal file
21
Makefile
Normal file
|
@ -0,0 +1,21 @@
|
|||
fmt:
|
||||
@gofumpt -l -w .
|
||||
@gofmt -s -w .
|
||||
@gci write --custom-order -s "standard,prefix(github.com/sagernet/),default" .
|
||||
|
||||
fmt_install:
|
||||
go install -v mvdan.cc/gofumpt@latest
|
||||
go install -v github.com/daixiang0/gci@latest
|
||||
|
||||
lint:
|
||||
GOOS=linux golangci-lint run ./...
|
||||
GOOS=android golangci-lint run ./...
|
||||
GOOS=windows golangci-lint run ./...
|
||||
GOOS=darwin golangci-lint run ./...
|
||||
GOOS=freebsd golangci-lint run ./...
|
||||
|
||||
lint_install:
|
||||
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
||||
|
||||
test:
|
||||
go test -v ./...
|
3
README.md
Normal file
3
README.md
Normal file
|
@ -0,0 +1,3 @@
|
|||
# sing-mux
|
||||
|
||||
Simple multiplex library.
|
183
client.go
Normal file
183
client.go
Normal file
|
@ -0,0 +1,183 @@
|
|||
package mux
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/x/list"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
dialer N.Dialer
|
||||
protocol byte
|
||||
maxConnections int
|
||||
minStreams int
|
||||
maxStreams int
|
||||
padding bool
|
||||
access sync.Mutex
|
||||
connections list.List[abstractSession]
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
Dialer N.Dialer
|
||||
Protocol string
|
||||
MaxConnections int
|
||||
MinStreams int
|
||||
MaxStreams int
|
||||
Padding bool
|
||||
}
|
||||
|
||||
func NewClient(options Options) (*Client, error) {
|
||||
client := &Client{
|
||||
dialer: options.Dialer,
|
||||
maxConnections: options.MaxConnections,
|
||||
minStreams: options.MinStreams,
|
||||
maxStreams: options.MaxStreams,
|
||||
padding: options.Padding,
|
||||
}
|
||||
if client.dialer == nil {
|
||||
client.dialer = N.SystemDialer
|
||||
}
|
||||
if client.maxStreams == 0 && client.maxConnections == 0 {
|
||||
client.minStreams = 8
|
||||
}
|
||||
switch options.Protocol {
|
||||
case "", "h2mux":
|
||||
client.protocol = ProtocolH2Mux
|
||||
case "smux":
|
||||
client.protocol = ProtocolSmux
|
||||
case "yamux":
|
||||
client.protocol = ProtocolYAMux
|
||||
default:
|
||||
return nil, E.New("unknown protocol: " + options.Protocol)
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *Client) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
switch N.NetworkName(network) {
|
||||
case N.NetworkTCP:
|
||||
stream, err := c.openStream(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &clientConn{Conn: stream, destination: destination}, nil
|
||||
case N.NetworkUDP:
|
||||
stream, err := c.openStream(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bufio.NewUnbindPacketConn(&clientPacketConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: destination}), nil
|
||||
default:
|
||||
return nil, E.Extend(N.ErrUnknownNetwork, network)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
stream, err := c.openStream(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &clientPacketAddrConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: destination}, nil
|
||||
}
|
||||
|
||||
func (c *Client) openStream(ctx context.Context) (net.Conn, error) {
|
||||
var (
|
||||
session abstractSession
|
||||
stream net.Conn
|
||||
err error
|
||||
)
|
||||
for attempts := 0; attempts < 2; attempts++ {
|
||||
session, err = c.offer(ctx)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
stream, err = session.Open()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &wrapStream{stream}, nil
|
||||
}
|
||||
|
||||
func (c *Client) offer(ctx context.Context) (abstractSession, error) {
|
||||
c.access.Lock()
|
||||
defer c.access.Unlock()
|
||||
|
||||
sessions := make([]abstractSession, 0, c.maxConnections)
|
||||
for element := c.connections.Front(); element != nil; {
|
||||
if element.Value.IsClosed() {
|
||||
nextElement := element.Next()
|
||||
c.connections.Remove(element)
|
||||
element = nextElement
|
||||
continue
|
||||
}
|
||||
sessions = append(sessions, element.Value)
|
||||
element = element.Next()
|
||||
}
|
||||
session := common.MinBy(common.Filter(sessions, abstractSession.CanTakeNewRequest), abstractSession.NumStreams)
|
||||
if session == nil {
|
||||
return c.offerNew(ctx)
|
||||
}
|
||||
numStreams := session.NumStreams()
|
||||
if numStreams == 0 {
|
||||
return session, nil
|
||||
}
|
||||
if c.maxConnections > 0 {
|
||||
if len(sessions) >= c.maxConnections || numStreams < c.minStreams {
|
||||
return session, nil
|
||||
}
|
||||
} else {
|
||||
if c.maxStreams > 0 && numStreams < c.maxStreams {
|
||||
return session, nil
|
||||
}
|
||||
}
|
||||
return c.offerNew(ctx)
|
||||
}
|
||||
|
||||
func (c *Client) offerNew(ctx context.Context) (abstractSession, error) {
|
||||
conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, Destination)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var version byte
|
||||
if c.padding {
|
||||
version = Version1
|
||||
} else {
|
||||
version = Version0
|
||||
}
|
||||
conn = newProtocolConn(conn, Request{
|
||||
Version: version,
|
||||
Protocol: c.protocol,
|
||||
Padding: c.padding,
|
||||
})
|
||||
if c.padding {
|
||||
conn = newPaddingConn(conn)
|
||||
}
|
||||
session, err := newClientSession(conn, c.protocol)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
c.connections.PushBack(session)
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (c *Client) Reset() {
|
||||
c.access.Lock()
|
||||
defer c.access.Unlock()
|
||||
for _, session := range c.connections.Array() {
|
||||
session.Close()
|
||||
}
|
||||
c.connections.Init()
|
||||
}
|
380
client_conn.go
Normal file
380
client_conn.go
Normal file
|
@ -0,0 +1,380 @@
|
|||
package mux
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type clientConn struct {
|
||||
net.Conn
|
||||
destination M.Socksaddr
|
||||
requestWritten bool
|
||||
responseRead bool
|
||||
}
|
||||
|
||||
func (c *clientConn) readResponse() error {
|
||||
response, err := ReadStreamResponse(c.Conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if response.Status == statusError {
|
||||
return E.New("remote error: ", response.Message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientConn) Read(b []byte) (n int, err error) {
|
||||
if !c.responseRead {
|
||||
err = c.readResponse()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.responseRead = true
|
||||
}
|
||||
return c.Conn.Read(b)
|
||||
}
|
||||
|
||||
func (c *clientConn) Write(b []byte) (n int, err error) {
|
||||
if c.requestWritten {
|
||||
return c.Conn.Write(b)
|
||||
}
|
||||
request := StreamRequest{
|
||||
Network: N.NetworkTCP,
|
||||
Destination: c.destination,
|
||||
}
|
||||
_buffer := buf.StackNewSize(streamRequestLen(request) + len(b))
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
EncodeStreamRequest(request, buffer)
|
||||
buffer.Write(b)
|
||||
_, err = c.Conn.Write(buffer.Bytes())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.requestWritten = true
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
if !c.requestWritten {
|
||||
return bufio.ReadFrom0(c, r)
|
||||
}
|
||||
return bufio.Copy(c.Conn, r)
|
||||
}
|
||||
|
||||
func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||
if !c.responseRead {
|
||||
return bufio.WriteTo0(c, w)
|
||||
}
|
||||
return bufio.Copy(w, c.Conn)
|
||||
}
|
||||
|
||||
func (c *clientConn) LocalAddr() net.Addr {
|
||||
return c.Conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *clientConn) RemoteAddr() net.Addr {
|
||||
return c.destination.TCPAddr()
|
||||
}
|
||||
|
||||
func (c *clientConn) ReaderReplaceable() bool {
|
||||
return c.responseRead
|
||||
}
|
||||
|
||||
func (c *clientConn) WriterReplaceable() bool {
|
||||
return c.requestWritten
|
||||
}
|
||||
|
||||
func (c *clientConn) NeedAdditionalReadDeadline() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *clientConn) Upstream() any {
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
type clientPacketConn struct {
|
||||
N.ExtendedConn
|
||||
destination M.Socksaddr
|
||||
requestWritten bool
|
||||
responseRead bool
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) readResponse() error {
|
||||
response, err := ReadStreamResponse(c.ExtendedConn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if response.Status == statusError {
|
||||
return E.New("remote error: ", response.Message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) Read(b []byte) (n int, err error) {
|
||||
if !c.responseRead {
|
||||
err = c.readResponse()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.responseRead = true
|
||||
}
|
||||
var length uint16
|
||||
err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if cap(b) < int(length) {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
return io.ReadFull(c.ExtendedConn, b[:length])
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) writeRequest(payload []byte) (n int, err error) {
|
||||
request := StreamRequest{
|
||||
Network: N.NetworkUDP,
|
||||
Destination: c.destination,
|
||||
}
|
||||
rLen := streamRequestLen(request)
|
||||
if len(payload) > 0 {
|
||||
rLen += 2 + len(payload)
|
||||
}
|
||||
_buffer := buf.StackNewSize(rLen)
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
EncodeStreamRequest(request, buffer)
|
||||
if len(payload) > 0 {
|
||||
common.Must(
|
||||
binary.Write(buffer, binary.BigEndian, uint16(len(payload))),
|
||||
common.Error(buffer.Write(payload)),
|
||||
)
|
||||
}
|
||||
_, err = c.ExtendedConn.Write(buffer.Bytes())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.requestWritten = true
|
||||
return len(payload), nil
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) Write(b []byte) (n int, err error) {
|
||||
if !c.requestWritten {
|
||||
return c.writeRequest(b)
|
||||
}
|
||||
err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(b)))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return c.ExtendedConn.Write(b)
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) ReadBuffer(buffer *buf.Buffer) (err error) {
|
||||
if !c.responseRead {
|
||||
err = c.readResponse()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.responseRead = true
|
||||
}
|
||||
var length uint16
|
||||
err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = buffer.ReadFullFrom(c.ExtendedConn, int(length))
|
||||
return
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) WriteBuffer(buffer *buf.Buffer) error {
|
||||
if !c.requestWritten {
|
||||
defer buffer.Release()
|
||||
return common.Error(c.writeRequest(buffer.Bytes()))
|
||||
}
|
||||
bLen := buffer.Len()
|
||||
binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(bLen))
|
||||
return c.ExtendedConn.WriteBuffer(buffer)
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) FrontHeadroom() int {
|
||||
return 2
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
err = c.ReadBuffer(buffer)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
return c.WriteBuffer(buffer)
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) LocalAddr() net.Addr {
|
||||
return c.ExtendedConn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) RemoteAddr() net.Addr {
|
||||
return c.destination.UDPAddr()
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) NeedAdditionalReadDeadline() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) Upstream() any {
|
||||
return c.ExtendedConn
|
||||
}
|
||||
|
||||
var _ N.NetPacketConn = (*clientPacketAddrConn)(nil)
|
||||
|
||||
type clientPacketAddrConn struct {
|
||||
N.ExtendedConn
|
||||
destination M.Socksaddr
|
||||
requestWritten bool
|
||||
responseRead bool
|
||||
}
|
||||
|
||||
func (c *clientPacketAddrConn) readResponse() error {
|
||||
response, err := ReadStreamResponse(c.ExtendedConn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if response.Status == statusError {
|
||||
return E.New("remote error: ", response.Message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
if !c.responseRead {
|
||||
err = c.readResponse()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.responseRead = true
|
||||
}
|
||||
destination, err := M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if destination.IsFqdn() {
|
||||
addr = destination
|
||||
} else {
|
||||
addr = destination.UDPAddr()
|
||||
}
|
||||
var length uint16
|
||||
err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if cap(p) < int(length) {
|
||||
return 0, nil, io.ErrShortBuffer
|
||||
}
|
||||
n, err = io.ReadFull(c.ExtendedConn, p[:length])
|
||||
return
|
||||
}
|
||||
|
||||
func (c *clientPacketAddrConn) writeRequest(payload []byte, destination M.Socksaddr) (n int, err error) {
|
||||
request := StreamRequest{
|
||||
Network: N.NetworkUDP,
|
||||
Destination: c.destination,
|
||||
PacketAddr: true,
|
||||
}
|
||||
rLen := streamRequestLen(request)
|
||||
if len(payload) > 0 {
|
||||
rLen += M.SocksaddrSerializer.AddrPortLen(destination) + 2 + len(payload)
|
||||
}
|
||||
_buffer := buf.StackNewSize(rLen)
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
EncodeStreamRequest(request, buffer)
|
||||
if len(payload) > 0 {
|
||||
common.Must(
|
||||
M.SocksaddrSerializer.WriteAddrPort(buffer, destination),
|
||||
binary.Write(buffer, binary.BigEndian, uint16(len(payload))),
|
||||
common.Error(buffer.Write(payload)),
|
||||
)
|
||||
}
|
||||
_, err = c.ExtendedConn.Write(buffer.Bytes())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.requestWritten = true
|
||||
return len(payload), nil
|
||||
}
|
||||
|
||||
func (c *clientPacketAddrConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
if !c.requestWritten {
|
||||
return c.writeRequest(p, M.SocksaddrFromNet(addr))
|
||||
}
|
||||
err = M.SocksaddrSerializer.WriteAddrPort(c.ExtendedConn, M.SocksaddrFromNet(addr))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(p)))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return c.ExtendedConn.Write(p)
|
||||
}
|
||||
|
||||
func (c *clientPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
if !c.responseRead {
|
||||
err = c.readResponse()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.responseRead = true
|
||||
}
|
||||
destination, err = M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var length uint16
|
||||
err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = buffer.ReadFullFrom(c.ExtendedConn, int(length))
|
||||
return
|
||||
}
|
||||
|
||||
func (c *clientPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
if !c.requestWritten {
|
||||
defer buffer.Release()
|
||||
return common.Error(c.writeRequest(buffer.Bytes(), destination))
|
||||
}
|
||||
bLen := buffer.Len()
|
||||
header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination) + 2))
|
||||
common.Must(
|
||||
M.SocksaddrSerializer.WriteAddrPort(header, destination),
|
||||
binary.Write(header, binary.BigEndian, uint16(bLen)),
|
||||
)
|
||||
return c.ExtendedConn.WriteBuffer(buffer)
|
||||
}
|
||||
|
||||
func (c *clientPacketAddrConn) LocalAddr() net.Addr {
|
||||
return c.ExtendedConn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *clientPacketAddrConn) FrontHeadroom() int {
|
||||
return 2 + M.MaxSocksaddrLength
|
||||
}
|
||||
|
||||
func (c *clientPacketAddrConn) NeedAdditionalReadDeadline() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *clientPacketAddrConn) Upstream() any {
|
||||
return c.ExtendedConn
|
||||
}
|
37
error.go
Normal file
37
error.go
Normal file
|
@ -0,0 +1,37 @@
|
|||
package mux
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
)
|
||||
|
||||
type wrapStream struct {
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (w *wrapStream) Read(p []byte) (n int, err error) {
|
||||
n, err = w.Conn.Read(p)
|
||||
err = wrapError(err)
|
||||
return
|
||||
}
|
||||
|
||||
func (w *wrapStream) Write(p []byte) (n int, err error) {
|
||||
n, err = w.Conn.Write(p)
|
||||
err = wrapError(err)
|
||||
return
|
||||
}
|
||||
|
||||
func (w *wrapStream) Upstream() any {
|
||||
return w.Conn
|
||||
}
|
||||
|
||||
func wrapError(err error) error {
|
||||
switch err {
|
||||
case yamux.ErrStreamClosed:
|
||||
return io.EOF
|
||||
default:
|
||||
return err
|
||||
}
|
||||
}
|
15
go.mod
Normal file
15
go.mod
Normal file
|
@ -0,0 +1,15 @@
|
|||
module github.com/sagernet/sing-mux
|
||||
|
||||
go 1.18
|
||||
|
||||
require (
|
||||
github.com/hashicorp/yamux v0.1.1
|
||||
github.com/sagernet/sing v0.2.5-0.20230423085534-0902e6216207
|
||||
github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37
|
||||
golang.org/x/net v0.9.0
|
||||
)
|
||||
|
||||
require (
|
||||
golang.org/x/sys v0.7.0 // indirect
|
||||
golang.org/x/text v0.9.0 // indirect
|
||||
)
|
14
go.sum
Normal file
14
go.sum
Normal file
|
@ -0,0 +1,14 @@
|
|||
github.com/hashicorp/yamux v0.1.1 h1:yrQxtgseBDrq9Y652vSRDvsKCJKOUD+GzTS4Y0Y8pvE=
|
||||
github.com/hashicorp/yamux v0.1.1/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ=
|
||||
github.com/sagernet/sing v0.1.8/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk=
|
||||
github.com/sagernet/sing v0.2.5-0.20230423085534-0902e6216207 h1:+dDVjW20IT+e8maKryaDeRY2+RFmTFdrQeIzqE2WOss=
|
||||
github.com/sagernet/sing v0.2.5-0.20230423085534-0902e6216207/go.mod h1:Ta8nHnDLAwqySzKhGoKk4ZIB+vJ3GTKj7UPrWYvM+4w=
|
||||
github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37 h1:HuE6xSwco/Xed8ajZ+coeYLmioq0Qp1/Z2zczFaV8as=
|
||||
github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37/go.mod h1:3skNSftZDJWTGVtVaM2jfbce8qHnmH/AGDRe62iNOg0=
|
||||
golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM=
|
||||
golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU=
|
||||
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
242
h2mux.go
Normal file
242
h2mux.go
Normal file
|
@ -0,0 +1,242 @@
|
|||
package mux
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
const idleTimeout = 30 * time.Second
|
||||
|
||||
var _ abstractSession = (*h2MuxServerSession)(nil)
|
||||
|
||||
type h2MuxServerSession struct {
|
||||
server http2.Server
|
||||
active atomic.Int32
|
||||
conn net.Conn
|
||||
inbound chan net.Conn
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func newH2MuxServer(conn net.Conn) *h2MuxServerSession {
|
||||
session := &h2MuxServerSession{
|
||||
conn: conn,
|
||||
inbound: make(chan net.Conn),
|
||||
done: make(chan struct{}),
|
||||
server: http2.Server{
|
||||
IdleTimeout: idleTimeout,
|
||||
},
|
||||
}
|
||||
go func() {
|
||||
session.server.ServeConn(conn, &http2.ServeConnOpts{
|
||||
Handler: session,
|
||||
})
|
||||
_ = session.Close()
|
||||
}()
|
||||
return session
|
||||
}
|
||||
|
||||
func (s *h2MuxServerSession) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
||||
s.active.Add(1)
|
||||
defer s.active.Add(-1)
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
conn := newHTTP2Wrapper(newHTTPConn(request.Body, writer), writer.(http.Flusher))
|
||||
s.inbound <- conn
|
||||
select {
|
||||
case <-conn.done:
|
||||
case <-s.done:
|
||||
_ = conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *h2MuxServerSession) Open() (net.Conn, error) {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
|
||||
func (s *h2MuxServerSession) Accept() (net.Conn, error) {
|
||||
select {
|
||||
case conn := <-s.inbound:
|
||||
return conn, nil
|
||||
case <-s.done:
|
||||
return nil, os.ErrClosed
|
||||
}
|
||||
}
|
||||
|
||||
func (s *h2MuxServerSession) NumStreams() int {
|
||||
return int(s.active.Load())
|
||||
}
|
||||
|
||||
func (s *h2MuxServerSession) Close() error {
|
||||
select {
|
||||
case <-s.done:
|
||||
default:
|
||||
close(s.done)
|
||||
}
|
||||
return s.conn.Close()
|
||||
}
|
||||
|
||||
func (s *h2MuxServerSession) IsClosed() bool {
|
||||
select {
|
||||
case <-s.done:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *h2MuxServerSession) CanTakeNewRequest() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
type h2MuxConnWrapper struct {
|
||||
N.ExtendedConn
|
||||
flusher http.Flusher
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func newHTTP2Wrapper(conn net.Conn, flusher http.Flusher) *h2MuxConnWrapper {
|
||||
return &h2MuxConnWrapper{
|
||||
ExtendedConn: bufio.NewExtendedConn(conn),
|
||||
flusher: flusher,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (w *h2MuxConnWrapper) Write(p []byte) (n int, err error) {
|
||||
select {
|
||||
case <-w.done:
|
||||
return 0, net.ErrClosed
|
||||
default:
|
||||
}
|
||||
n, err = w.ExtendedConn.Write(p)
|
||||
if err == nil {
|
||||
w.flusher.Flush()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (w *h2MuxConnWrapper) WriteBuffer(buffer *buf.Buffer) error {
|
||||
select {
|
||||
case <-w.done:
|
||||
return net.ErrClosed
|
||||
default:
|
||||
}
|
||||
err := w.ExtendedConn.WriteBuffer(buffer)
|
||||
if err == nil {
|
||||
w.flusher.Flush()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *h2MuxConnWrapper) Close() error {
|
||||
select {
|
||||
case <-w.done:
|
||||
default:
|
||||
close(w.done)
|
||||
}
|
||||
return w.ExtendedConn.Close()
|
||||
}
|
||||
|
||||
func (w *h2MuxConnWrapper) Upstream() any {
|
||||
return w.ExtendedConn
|
||||
}
|
||||
|
||||
var _ abstractSession = (*h2MuxClientSession)(nil)
|
||||
|
||||
type h2MuxClientSession struct {
|
||||
transport *http2.Transport
|
||||
clientConn *http2.ClientConn
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func newH2MuxClient(conn net.Conn) (*h2MuxClientSession, error) {
|
||||
session := &h2MuxClientSession{
|
||||
transport: &http2.Transport{
|
||||
DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
|
||||
return conn, nil
|
||||
},
|
||||
ReadIdleTimeout: idleTimeout,
|
||||
},
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
session.transport.ConnPool = session
|
||||
clientConn, err := session.transport.NewClientConn(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
session.clientConn = clientConn
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (s *h2MuxClientSession) GetClientConn(req *http.Request, addr string) (*http2.ClientConn, error) {
|
||||
return s.clientConn, nil
|
||||
}
|
||||
|
||||
func (s *h2MuxClientSession) MarkDead(conn *http2.ClientConn) {
|
||||
s.Close()
|
||||
}
|
||||
|
||||
func (s *h2MuxClientSession) Open() (net.Conn, error) {
|
||||
pipeInReader, pipeInWriter := io.Pipe()
|
||||
request := &http.Request{
|
||||
Method: http.MethodConnect,
|
||||
Body: pipeInReader,
|
||||
URL: &url.URL{Scheme: "https", Host: "localhost"},
|
||||
}
|
||||
conn := newLateHTTPConn(pipeInWriter)
|
||||
go func() {
|
||||
response, err := s.transport.RoundTrip(request)
|
||||
if err != nil {
|
||||
conn.setup(nil, err)
|
||||
} else if response.StatusCode != 200 {
|
||||
response.Body.Close()
|
||||
conn.setup(nil, E.New("unexpected status: ", response.StatusCode, " ", response.Status))
|
||||
} else {
|
||||
conn.setup(response.Body, nil)
|
||||
}
|
||||
}()
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (s *h2MuxClientSession) Accept() (net.Conn, error) {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
|
||||
func (s *h2MuxClientSession) NumStreams() int {
|
||||
return s.clientConn.State().StreamsActive
|
||||
}
|
||||
|
||||
func (s *h2MuxClientSession) Close() error {
|
||||
select {
|
||||
case <-s.done:
|
||||
default:
|
||||
close(s.done)
|
||||
}
|
||||
return s.clientConn.Close()
|
||||
}
|
||||
|
||||
func (s *h2MuxClientSession) IsClosed() bool {
|
||||
select {
|
||||
case <-s.done:
|
||||
return true
|
||||
default:
|
||||
}
|
||||
return s.clientConn.State().Closed
|
||||
}
|
||||
|
||||
func (s *h2MuxClientSession) CanTakeNewRequest() bool {
|
||||
return s.clientConn.CanTakeNewRequest()
|
||||
}
|
82
h2mux_conn.go
Normal file
82
h2mux_conn.go
Normal file
|
@ -0,0 +1,82 @@
|
|||
package mux
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/baderror"
|
||||
)
|
||||
|
||||
type httpConn struct {
|
||||
reader io.Reader
|
||||
writer io.Writer
|
||||
create chan struct{}
|
||||
err error
|
||||
}
|
||||
|
||||
func newHTTPConn(reader io.Reader, writer io.Writer) *httpConn {
|
||||
return &httpConn{
|
||||
reader: reader,
|
||||
writer: writer,
|
||||
}
|
||||
}
|
||||
|
||||
func newLateHTTPConn(writer io.Writer) *httpConn {
|
||||
return &httpConn{
|
||||
create: make(chan struct{}),
|
||||
writer: writer,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *httpConn) setup(reader io.Reader, err error) {
|
||||
c.reader = reader
|
||||
c.err = err
|
||||
close(c.create)
|
||||
}
|
||||
|
||||
func (c *httpConn) Read(b []byte) (n int, err error) {
|
||||
if c.reader == nil {
|
||||
<-c.create
|
||||
if c.err != nil {
|
||||
return 0, c.err
|
||||
}
|
||||
}
|
||||
n, err = c.reader.Read(b)
|
||||
return n, baderror.WrapH2(err)
|
||||
}
|
||||
|
||||
func (c *httpConn) Write(b []byte) (n int, err error) {
|
||||
n, err = c.writer.Write(b)
|
||||
return n, baderror.WrapH2(err)
|
||||
}
|
||||
|
||||
func (c *httpConn) Close() error {
|
||||
return common.Close(c.reader, c.writer)
|
||||
}
|
||||
|
||||
func (c *httpConn) LocalAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *httpConn) RemoteAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *httpConn) SetDeadline(t time.Time) error {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
|
||||
func (c *httpConn) SetReadDeadline(t time.Time) error {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
|
||||
func (c *httpConn) SetWriteDeadline(t time.Time) error {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
|
||||
func (c *httpConn) NeedAdditionalReadDeadline() bool {
|
||||
return true
|
||||
}
|
240
padding.go
Normal file
240
padding.go
Normal file
|
@ -0,0 +1,240 @@
|
|||
package mux
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
)
|
||||
|
||||
const kFirstPaddings = 16
|
||||
|
||||
type paddingConn struct {
|
||||
N.ExtendedConn
|
||||
writer N.VectorisedWriter
|
||||
readPadding int
|
||||
writePadding int
|
||||
readRemaining int
|
||||
paddingRemaining int
|
||||
}
|
||||
|
||||
func newPaddingConn(conn net.Conn) net.Conn {
|
||||
writer, isVectorised := bufio.CreateVectorisedWriter(conn)
|
||||
if isVectorised {
|
||||
return &vectorisedPaddingConn{
|
||||
paddingConn{
|
||||
ExtendedConn: bufio.NewExtendedConn(conn),
|
||||
writer: bufio.NewVectorisedWriter(conn),
|
||||
},
|
||||
writer,
|
||||
}
|
||||
} else {
|
||||
return &paddingConn{
|
||||
ExtendedConn: bufio.NewExtendedConn(conn),
|
||||
writer: bufio.NewVectorisedWriter(conn),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *paddingConn) Read(p []byte) (n int, err error) {
|
||||
if c.readRemaining > 0 {
|
||||
if len(p) > c.readRemaining {
|
||||
p = p[:c.readRemaining]
|
||||
}
|
||||
n, err = c.ExtendedConn.Read(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.readRemaining -= n
|
||||
return
|
||||
}
|
||||
if c.paddingRemaining > 0 {
|
||||
err = rw.SkipN(c.ExtendedConn, c.paddingRemaining)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.paddingRemaining = 0
|
||||
}
|
||||
if c.readPadding < kFirstPaddings {
|
||||
var paddingHdr []byte
|
||||
if len(p) >= 4 {
|
||||
paddingHdr = p[:4]
|
||||
} else {
|
||||
_paddingHdr := make([]byte, 4)
|
||||
defer common.KeepAlive(_paddingHdr)
|
||||
paddingHdr = common.Dup(_paddingHdr)
|
||||
}
|
||||
_, err = io.ReadFull(c.ExtendedConn, paddingHdr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2]))
|
||||
paddingLen := int(binary.BigEndian.Uint16(paddingHdr[2:]))
|
||||
if len(p) > originalDataSize {
|
||||
p = p[:originalDataSize]
|
||||
}
|
||||
n, err = c.ExtendedConn.Read(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.readPadding++
|
||||
c.readRemaining = originalDataSize - n
|
||||
c.paddingRemaining = paddingLen
|
||||
return
|
||||
}
|
||||
return c.ExtendedConn.Read(p)
|
||||
}
|
||||
|
||||
func (c *paddingConn) Write(p []byte) (n int, err error) {
|
||||
for pLen := len(p); pLen > 0; {
|
||||
var data []byte
|
||||
if pLen > 65535 {
|
||||
data = p[:65535]
|
||||
p = p[65535:]
|
||||
pLen -= 65535
|
||||
} else {
|
||||
data = p
|
||||
pLen = 0
|
||||
}
|
||||
var writeN int
|
||||
writeN, err = c.write(data)
|
||||
n += writeN
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *paddingConn) write(p []byte) (n int, err error) {
|
||||
if c.writePadding < kFirstPaddings {
|
||||
paddingLen := 256 + rand.Intn(512)
|
||||
_buffer := buf.StackNewSize(4 + len(p) + paddingLen)
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
header := buffer.Extend(4)
|
||||
binary.BigEndian.PutUint16(header[:2], uint16(len(p)))
|
||||
binary.BigEndian.PutUint16(header[2:], uint16(paddingLen))
|
||||
common.Must1(buffer.Write(p))
|
||||
buffer.Extend(paddingLen)
|
||||
_, err = c.ExtendedConn.Write(buffer.Bytes())
|
||||
if err == nil {
|
||||
n = len(p)
|
||||
}
|
||||
c.writePadding++
|
||||
return
|
||||
}
|
||||
return c.ExtendedConn.Write(p)
|
||||
}
|
||||
|
||||
func (c *paddingConn) ReadBuffer(buffer *buf.Buffer) error {
|
||||
p := buffer.FreeBytes()
|
||||
if c.readRemaining > 0 {
|
||||
if len(p) > c.readRemaining {
|
||||
p = p[:c.readRemaining]
|
||||
}
|
||||
n, err := c.ExtendedConn.Read(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.readRemaining -= n
|
||||
buffer.Truncate(n)
|
||||
return nil
|
||||
}
|
||||
if c.paddingRemaining > 0 {
|
||||
err := rw.SkipN(c.ExtendedConn, c.paddingRemaining)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.paddingRemaining = 0
|
||||
}
|
||||
if c.readPadding < kFirstPaddings {
|
||||
var paddingHdr []byte
|
||||
if len(p) >= 4 {
|
||||
paddingHdr = p[:4]
|
||||
} else {
|
||||
_paddingHdr := make([]byte, 4)
|
||||
defer common.KeepAlive(_paddingHdr)
|
||||
paddingHdr = common.Dup(_paddingHdr)
|
||||
}
|
||||
_, err := io.ReadFull(c.ExtendedConn, paddingHdr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2]))
|
||||
paddingLen := int(binary.BigEndian.Uint16(paddingHdr[2:]))
|
||||
|
||||
if len(p) > originalDataSize {
|
||||
p = p[:originalDataSize]
|
||||
}
|
||||
n, err := c.ExtendedConn.Read(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.readPadding++
|
||||
c.readRemaining = originalDataSize - n
|
||||
c.paddingRemaining = paddingLen
|
||||
buffer.Truncate(n)
|
||||
return nil
|
||||
}
|
||||
return c.ExtendedConn.ReadBuffer(buffer)
|
||||
}
|
||||
|
||||
func (c *paddingConn) WriteBuffer(buffer *buf.Buffer) error {
|
||||
if c.writePadding < kFirstPaddings {
|
||||
bufferLen := buffer.Len()
|
||||
if bufferLen > 65535 {
|
||||
return common.Error(c.Write(buffer.Bytes()))
|
||||
}
|
||||
paddingLen := 256 + rand.Intn(512)
|
||||
header := buffer.ExtendHeader(4)
|
||||
binary.BigEndian.PutUint16(header[:2], uint16(bufferLen))
|
||||
binary.BigEndian.PutUint16(header[2:], uint16(paddingLen))
|
||||
buffer.Extend(paddingLen)
|
||||
c.writePadding++
|
||||
}
|
||||
return c.ExtendedConn.WriteBuffer(buffer)
|
||||
}
|
||||
|
||||
func (c *paddingConn) FrontHeadroom() int {
|
||||
return 4 + 256 + 1024
|
||||
}
|
||||
|
||||
type vectorisedPaddingConn struct {
|
||||
paddingConn
|
||||
writer N.VectorisedWriter
|
||||
}
|
||||
|
||||
func (c *vectorisedPaddingConn) WriteVectorised(buffers []*buf.Buffer) error {
|
||||
if c.writePadding < kFirstPaddings {
|
||||
bufferLen := buf.LenMulti(buffers)
|
||||
if bufferLen > 65535 {
|
||||
defer buf.ReleaseMulti(buffers)
|
||||
for _, buffer := range buffers {
|
||||
_, err := c.Write(buffer.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
paddingLen := 256 + rand.Intn(512)
|
||||
header := buf.NewSize(4)
|
||||
common.Must(
|
||||
binary.Write(header, binary.BigEndian, uint16(bufferLen)),
|
||||
binary.Write(header, binary.BigEndian, uint16(paddingLen)),
|
||||
)
|
||||
c.writePadding++
|
||||
padding := buf.NewSize(paddingLen)
|
||||
padding.Extend(paddingLen)
|
||||
buffers = append(append([]*buf.Buffer{header}, buffers...), padding)
|
||||
}
|
||||
return c.writer.WriteVectorised(buffers)
|
||||
}
|
183
protocol.go
Normal file
183
protocol.go
Normal file
|
@ -0,0 +1,183 @@
|
|||
package mux
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
)
|
||||
|
||||
const (
|
||||
ProtocolSmux = iota
|
||||
ProtocolYAMux
|
||||
ProtocolH2Mux
|
||||
)
|
||||
|
||||
const (
|
||||
Version0 = iota
|
||||
Version1
|
||||
)
|
||||
|
||||
const (
|
||||
TCPTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
var Destination = M.Socksaddr{
|
||||
Fqdn: "sp.mux.sing-box.arpa",
|
||||
Port: 444,
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
Version byte
|
||||
Protocol byte
|
||||
Padding bool
|
||||
}
|
||||
|
||||
func ReadRequest(reader io.Reader) (*Request, error) {
|
||||
version, err := rw.ReadByte(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if version < Version0 || version > Version1 {
|
||||
return nil, E.New("unsupported version: ", version)
|
||||
}
|
||||
protocol, err := rw.ReadByte(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var paddingEnabled bool
|
||||
if version == Version1 {
|
||||
err = binary.Read(reader, binary.BigEndian, &paddingEnabled)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if paddingEnabled {
|
||||
var paddingLen uint16
|
||||
err = binary.Read(reader, binary.BigEndian, &paddingLen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = rw.SkipN(reader, int(paddingLen))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return &Request{Version: version, Protocol: protocol, Padding: paddingEnabled}, nil
|
||||
}
|
||||
|
||||
func EncodeRequest(request Request, payload []byte) *buf.Buffer {
|
||||
var requestLen int
|
||||
requestLen += 2
|
||||
var paddingLen uint16
|
||||
if request.Version == Version1 {
|
||||
requestLen += 1
|
||||
if request.Padding {
|
||||
requestLen += 2
|
||||
paddingLen = uint16(256 + rand.Intn(512))
|
||||
requestLen += int(paddingLen)
|
||||
}
|
||||
}
|
||||
buffer := buf.NewSize(requestLen + len(payload))
|
||||
common.Must(
|
||||
buffer.WriteByte(request.Version),
|
||||
buffer.WriteByte(request.Protocol),
|
||||
)
|
||||
if request.Version == Version1 {
|
||||
common.Must(binary.Write(buffer, binary.BigEndian, request.Padding))
|
||||
if request.Padding {
|
||||
common.Must(binary.Write(buffer, binary.BigEndian, paddingLen))
|
||||
buffer.Extend(int(paddingLen))
|
||||
}
|
||||
}
|
||||
common.Must1(buffer.Write(payload))
|
||||
return buffer
|
||||
}
|
||||
|
||||
const (
|
||||
flagUDP = 1
|
||||
flagAddr = 2
|
||||
statusSuccess = 0
|
||||
statusError = 1
|
||||
)
|
||||
|
||||
type StreamRequest struct {
|
||||
Network string
|
||||
Destination M.Socksaddr
|
||||
PacketAddr bool
|
||||
}
|
||||
|
||||
func ReadStreamRequest(reader io.Reader) (*StreamRequest, error) {
|
||||
var flags uint16
|
||||
err := binary.Read(reader, binary.BigEndian, &flags)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var network string
|
||||
var udpAddr bool
|
||||
if flags&flagUDP == 0 {
|
||||
network = N.NetworkTCP
|
||||
} else {
|
||||
network = N.NetworkUDP
|
||||
udpAddr = flags&flagAddr != 0
|
||||
}
|
||||
return &StreamRequest{network, destination, udpAddr}, nil
|
||||
}
|
||||
|
||||
func streamRequestLen(request StreamRequest) int {
|
||||
var rLen int
|
||||
rLen += 1 // version
|
||||
rLen += 2 // flags
|
||||
rLen += M.SocksaddrSerializer.AddrPortLen(request.Destination)
|
||||
return rLen
|
||||
}
|
||||
|
||||
func EncodeStreamRequest(request StreamRequest, buffer *buf.Buffer) {
|
||||
destination := request.Destination
|
||||
var flags uint16
|
||||
if request.Network == N.NetworkUDP {
|
||||
flags |= flagUDP
|
||||
}
|
||||
if request.PacketAddr {
|
||||
flags |= flagAddr
|
||||
if !destination.IsValid() {
|
||||
destination = Destination
|
||||
}
|
||||
}
|
||||
common.Must(
|
||||
binary.Write(buffer, binary.BigEndian, flags),
|
||||
M.SocksaddrSerializer.WriteAddrPort(buffer, destination),
|
||||
)
|
||||
}
|
||||
|
||||
type StreamResponse struct {
|
||||
Status uint8
|
||||
Message string
|
||||
}
|
||||
|
||||
func ReadStreamResponse(reader io.Reader) (*StreamResponse, error) {
|
||||
var response StreamResponse
|
||||
status, err := rw.ReadByte(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
response.Status = status
|
||||
if status == statusError {
|
||||
response.Message, err = rw.ReadVString(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &response, nil
|
||||
}
|
73
protocol_conn.go
Normal file
73
protocol_conn.go
Normal file
|
@ -0,0 +1,73 @@
|
|||
package mux
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type protocolConn struct {
|
||||
net.Conn
|
||||
request Request
|
||||
requestWritten bool
|
||||
}
|
||||
|
||||
func newProtocolConn(conn net.Conn, request Request) net.Conn {
|
||||
writer, isVectorised := bufio.CreateVectorisedWriter(conn)
|
||||
if isVectorised {
|
||||
return &vectorisedProtocolConn{
|
||||
protocolConn{
|
||||
Conn: conn,
|
||||
request: request,
|
||||
},
|
||||
writer,
|
||||
}
|
||||
} else {
|
||||
return &protocolConn{
|
||||
Conn: conn,
|
||||
request: request,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *protocolConn) Write(p []byte) (n int, err error) {
|
||||
if c.requestWritten {
|
||||
return c.Conn.Write(p)
|
||||
}
|
||||
buffer := EncodeRequest(c.request, p)
|
||||
n, err = c.Conn.Write(buffer.Bytes())
|
||||
buffer.Release()
|
||||
if err == nil {
|
||||
n--
|
||||
}
|
||||
c.requestWritten = true
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *protocolConn) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
if !c.requestWritten {
|
||||
return bufio.ReadFrom0(c, r)
|
||||
}
|
||||
return bufio.Copy(c.Conn, r)
|
||||
}
|
||||
|
||||
func (c *protocolConn) Upstream() any {
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
type vectorisedProtocolConn struct {
|
||||
protocolConn
|
||||
writer N.VectorisedWriter
|
||||
}
|
||||
|
||||
func (c *vectorisedProtocolConn) WriteVectorised(buffers []*buf.Buffer) error {
|
||||
if c.requestWritten {
|
||||
return c.writer.WriteVectorised(buffers)
|
||||
}
|
||||
c.requestWritten = true
|
||||
buffer := EncodeRequest(c.request, nil)
|
||||
return c.writer.WriteVectorised(append([]*buf.Buffer{buffer}, buffers...))
|
||||
}
|
80
server.go
Normal file
80
server.go
Normal file
|
@ -0,0 +1,80 @@
|
|||
package mux
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/task"
|
||||
)
|
||||
|
||||
type ServerHandler interface {
|
||||
N.TCPConnectionHandler
|
||||
N.UDPConnectionHandler
|
||||
E.Handler
|
||||
}
|
||||
|
||||
func HandleConnection(ctx context.Context, handler ServerHandler, logger logger.ContextLogger, conn net.Conn, metadata M.Metadata) error {
|
||||
request, err := ReadRequest(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if request.Padding {
|
||||
conn = newPaddingConn(conn)
|
||||
}
|
||||
session, err := newServerSession(conn, request.Protocol)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var group task.Group
|
||||
group.Append0(func(ctx context.Context) error {
|
||||
var stream net.Conn
|
||||
for {
|
||||
stream, err = session.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go newConnection(ctx, handler, logger, stream, metadata)
|
||||
}
|
||||
})
|
||||
group.Cleanup(func() {
|
||||
session.Close()
|
||||
})
|
||||
return group.Run(ctx)
|
||||
}
|
||||
|
||||
func newConnection(ctx context.Context, handler ServerHandler, logger logger.ContextLogger, stream net.Conn, metadata M.Metadata) {
|
||||
stream = &wrapStream{stream}
|
||||
request, err := ReadStreamRequest(stream)
|
||||
if err != nil {
|
||||
logger.ErrorContext(ctx, err)
|
||||
return
|
||||
}
|
||||
metadata.Destination = request.Destination
|
||||
if request.Network == N.NetworkTCP {
|
||||
logger.InfoContext(ctx, "inbound multiplex connection to ", metadata.Destination)
|
||||
hErr := handler.NewConnection(ctx, &serverConn{ExtendedConn: bufio.NewExtendedConn(stream)}, metadata)
|
||||
stream.Close()
|
||||
if hErr != nil {
|
||||
handler.NewError(ctx, hErr)
|
||||
}
|
||||
} else {
|
||||
var packetConn N.PacketConn
|
||||
if !request.PacketAddr {
|
||||
logger.InfoContext(ctx, "inbound multiplex packet connection to ", metadata.Destination)
|
||||
packetConn = &serverPacketConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: request.Destination}
|
||||
} else {
|
||||
logger.InfoContext(ctx, "inbound multiplex packet connection")
|
||||
packetConn = &serverPacketAddrConn{ExtendedConn: bufio.NewExtendedConn(stream)}
|
||||
}
|
||||
hErr := handler.NewPacketConnection(ctx, packetConn, metadata)
|
||||
stream.Close()
|
||||
if hErr != nil {
|
||||
handler.NewError(ctx, hErr)
|
||||
}
|
||||
}
|
||||
}
|
204
server_conn.go
Normal file
204
server_conn.go
Normal file
|
@ -0,0 +1,204 @@
|
|||
package mux
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
)
|
||||
|
||||
var _ N.HandshakeConn = (*serverConn)(nil)
|
||||
|
||||
type serverConn struct {
|
||||
N.ExtendedConn
|
||||
responseWritten bool
|
||||
}
|
||||
|
||||
func (c *serverConn) HandshakeFailure(err error) error {
|
||||
errMessage := err.Error()
|
||||
_buffer := buf.StackNewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage))
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
common.Must(
|
||||
buffer.WriteByte(statusError),
|
||||
rw.WriteVString(_buffer, errMessage),
|
||||
)
|
||||
return c.ExtendedConn.WriteBuffer(buffer)
|
||||
}
|
||||
|
||||
func (c *serverConn) Write(b []byte) (n int, err error) {
|
||||
if c.responseWritten {
|
||||
return c.ExtendedConn.Write(b)
|
||||
}
|
||||
_buffer := buf.StackNewSize(1 + len(b))
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
common.Must(
|
||||
buffer.WriteByte(statusSuccess),
|
||||
common.Error(buffer.Write(b)),
|
||||
)
|
||||
_, err = c.ExtendedConn.Write(buffer.Bytes())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.responseWritten = true
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (c *serverConn) WriteBuffer(buffer *buf.Buffer) error {
|
||||
if c.responseWritten {
|
||||
return c.ExtendedConn.WriteBuffer(buffer)
|
||||
}
|
||||
buffer.ExtendHeader(1)[0] = statusSuccess
|
||||
c.responseWritten = true
|
||||
return c.ExtendedConn.WriteBuffer(buffer)
|
||||
}
|
||||
|
||||
func (c *serverConn) FrontHeadroom() int {
|
||||
if !c.responseWritten {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (c *serverConn) NeedAdditionalReadDeadline() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *serverConn) Upstream() any {
|
||||
return c.ExtendedConn
|
||||
}
|
||||
|
||||
var (
|
||||
_ N.HandshakeConn = (*serverPacketConn)(nil)
|
||||
_ N.PacketConn = (*serverPacketConn)(nil)
|
||||
)
|
||||
|
||||
type serverPacketConn struct {
|
||||
N.ExtendedConn
|
||||
destination M.Socksaddr
|
||||
responseWritten bool
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) HandshakeFailure(err error) error {
|
||||
errMessage := err.Error()
|
||||
_buffer := buf.StackNewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage))
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
common.Must(
|
||||
buffer.WriteByte(statusError),
|
||||
rw.WriteVString(_buffer, errMessage),
|
||||
)
|
||||
return c.ExtendedConn.WriteBuffer(buffer)
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
var length uint16
|
||||
err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = buffer.ReadFullFrom(c.ExtendedConn, int(length))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
destination = c.destination
|
||||
return
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
pLen := buffer.Len()
|
||||
common.Must(binary.Write(buf.With(buffer.ExtendHeader(2)), binary.BigEndian, uint16(pLen)))
|
||||
if !c.responseWritten {
|
||||
buffer.ExtendHeader(1)[0] = statusSuccess
|
||||
c.responseWritten = true
|
||||
}
|
||||
return c.ExtendedConn.WriteBuffer(buffer)
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) NeedAdditionalReadDeadline() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) Upstream() any {
|
||||
return c.ExtendedConn
|
||||
}
|
||||
|
||||
func (c *serverPacketConn) FrontHeadroom() int {
|
||||
if !c.responseWritten {
|
||||
return 3
|
||||
}
|
||||
return 2
|
||||
}
|
||||
|
||||
var (
|
||||
_ N.HandshakeConn = (*serverPacketAddrConn)(nil)
|
||||
_ N.PacketConn = (*serverPacketAddrConn)(nil)
|
||||
)
|
||||
|
||||
type serverPacketAddrConn struct {
|
||||
N.ExtendedConn
|
||||
responseWritten bool
|
||||
}
|
||||
|
||||
func (c *serverPacketAddrConn) HandshakeFailure(err error) error {
|
||||
errMessage := err.Error()
|
||||
_buffer := buf.StackNewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage))
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
common.Must(
|
||||
buffer.WriteByte(statusError),
|
||||
rw.WriteVString(_buffer, errMessage),
|
||||
)
|
||||
return c.ExtendedConn.WriteBuffer(buffer)
|
||||
}
|
||||
|
||||
func (c *serverPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
destination, err = M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var length uint16
|
||||
err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = buffer.ReadFullFrom(c.ExtendedConn, int(length))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *serverPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
pLen := buffer.Len()
|
||||
common.Must(binary.Write(buf.With(buffer.ExtendHeader(2)), binary.BigEndian, uint16(pLen)))
|
||||
common.Must(M.SocksaddrSerializer.WriteAddrPort(buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination))), destination))
|
||||
if !c.responseWritten {
|
||||
buffer.ExtendHeader(1)[0] = statusSuccess
|
||||
c.responseWritten = true
|
||||
}
|
||||
return c.ExtendedConn.WriteBuffer(buffer)
|
||||
}
|
||||
|
||||
func (c *serverPacketAddrConn) NeedAdditionalReadDeadline() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *serverPacketAddrConn) Upstream() any {
|
||||
return c.ExtendedConn
|
||||
}
|
||||
|
||||
func (c *serverPacketAddrConn) FrontHeadroom() int {
|
||||
if !c.responseWritten {
|
||||
return 3 + M.MaxSocksaddrLength
|
||||
}
|
||||
return 2 + M.MaxSocksaddrLength
|
||||
}
|
36
server_default.go
Normal file
36
server_default.go
Normal file
|
@ -0,0 +1,36 @@
|
|||
package mux
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func HandleConnectionDefault(ctx context.Context, conn net.Conn) error {
|
||||
return HandleConnection(ctx, (*defaultServerHandler)(nil), logger.NOP(), conn, M.Metadata{})
|
||||
}
|
||||
|
||||
type defaultServerHandler struct{}
|
||||
|
||||
func (h *defaultServerHandler) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
|
||||
remoteConn, err := N.SystemDialer.DialContext(ctx, N.NetworkTCP, metadata.Destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return bufio.CopyConn(ctx, conn, remoteConn)
|
||||
}
|
||||
|
||||
func (h *defaultServerHandler) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error {
|
||||
remoteConn, err := N.SystemDialer.ListenPacket(ctx, metadata.Destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(remoteConn))
|
||||
}
|
||||
|
||||
func (h *defaultServerHandler) NewError(ctx context.Context, err error) {
|
||||
}
|
106
session.go
Normal file
106
session.go
Normal file
|
@ -0,0 +1,106 @@
|
|||
package mux
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/smux"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
)
|
||||
|
||||
type abstractSession interface {
|
||||
Open() (net.Conn, error)
|
||||
Accept() (net.Conn, error)
|
||||
NumStreams() int
|
||||
Close() error
|
||||
IsClosed() bool
|
||||
CanTakeNewRequest() bool
|
||||
}
|
||||
|
||||
func newClientSession(conn net.Conn, protocol byte) (abstractSession, error) {
|
||||
switch protocol {
|
||||
case ProtocolH2Mux:
|
||||
session, err := newH2MuxClient(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return session, nil
|
||||
case ProtocolSmux:
|
||||
client, err := smux.Client(conn, smuxConfig())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &smuxSession{client}, nil
|
||||
case ProtocolYAMux:
|
||||
client, err := yamux.Client(conn, yaMuxConfig())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &yamuxSession{client}, nil
|
||||
default:
|
||||
return nil, E.New("unexpected protocol ", protocol)
|
||||
}
|
||||
}
|
||||
|
||||
func newServerSession(conn net.Conn, protocol byte) (abstractSession, error) {
|
||||
switch protocol {
|
||||
case ProtocolH2Mux:
|
||||
return newH2MuxServer(conn), nil
|
||||
case ProtocolSmux:
|
||||
client, err := smux.Server(conn, smuxConfig())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &smuxSession{client}, nil
|
||||
case ProtocolYAMux:
|
||||
client, err := yamux.Server(conn, yaMuxConfig())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &yamuxSession{client}, nil
|
||||
default:
|
||||
return nil, E.New("unexpected protocol ", protocol)
|
||||
}
|
||||
}
|
||||
|
||||
var _ abstractSession = (*smuxSession)(nil)
|
||||
|
||||
type smuxSession struct {
|
||||
*smux.Session
|
||||
}
|
||||
|
||||
func (s *smuxSession) Open() (net.Conn, error) {
|
||||
return s.OpenStream()
|
||||
}
|
||||
|
||||
func (s *smuxSession) Accept() (net.Conn, error) {
|
||||
return s.AcceptStream()
|
||||
}
|
||||
|
||||
func (s *smuxSession) CanTakeNewRequest() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type yamuxSession struct {
|
||||
*yamux.Session
|
||||
}
|
||||
|
||||
func (y *yamuxSession) CanTakeNewRequest() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func smuxConfig() *smux.Config {
|
||||
config := smux.DefaultConfig()
|
||||
config.KeepAliveDisabled = true
|
||||
return config
|
||||
}
|
||||
|
||||
func yaMuxConfig() *yamux.Config {
|
||||
config := yamux.DefaultConfig()
|
||||
config.LogOutput = io.Discard
|
||||
config.StreamCloseTimeout = TCPTimeout
|
||||
config.StreamOpenTimeout = TCPTimeout
|
||||
return config
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue