refactor: proxymux

This commit rewrites proxymux package to provide following functions:

+ proxymux.ListenSOCKS(address string)
+ proxymux.ListenHTTP(address string)

both are drop-in replacements for net.Listen("tcp", address)

The above functions can be called with the same address to take
advantage of the mux feature.

Tests are not included, but we will have them very soon.

This commit should be in PR #1006, but I ended up with it in a separate
branch here. Please rebase if you want to merge it.
This commit is contained in:
Haruue 2024-04-11 20:53:28 +08:00
parent d34ff757c3
commit 34574e0339
No known key found for this signature in database
GPG key ID: F6083B28CBCBC148
5 changed files with 301 additions and 145 deletions

View file

@ -0,0 +1,72 @@
package proxymux
import (
"net"
"sync"
"github.com/apernet/hysteria/extras/correctnet"
)
type muxManager struct {
listeners map[string]*muxListener
lock sync.Mutex
}
var globalMuxManager *muxManager
func init() {
globalMuxManager = &muxManager{
listeners: make(map[string]*muxListener),
}
}
func (m *muxManager) GetOrCreate(address string) (*muxListener, error) {
key, err := m.canonicalizeAddrPort(address)
if err != nil {
return nil, err
}
m.lock.Lock()
defer m.lock.Unlock()
if ml, ok := m.listeners[key]; ok {
return ml, nil
}
listener, err := correctnet.Listen("tcp", key)
if err != nil {
return nil, err
}
ml := newMuxListener(listener, func() {
m.lock.Lock()
defer m.lock.Unlock()
delete(m.listeners, key)
})
m.listeners[key] = ml
return ml, nil
}
func (m *muxManager) canonicalizeAddrPort(address string) (string, error) {
taddr, err := net.ResolveTCPAddr("tcp", address)
if err != nil {
return "", err
}
return taddr.String(), nil
}
func ListenHTTP(address string) (net.Listener, error) {
ml, err := globalMuxManager.GetOrCreate(address)
if err != nil {
return nil, err
}
return ml.ListenHTTP()
}
func ListenSOCKS(address string) (net.Listener, error) {
ml, err := globalMuxManager.GetOrCreate(address)
if err != nil {
return nil, err
}
return ml.ListenSOCKS()
}

View file

@ -1,124 +1,257 @@
// Package proxymux splits a net.Listener in two, routing SOCKS5
// connections to one and HTTP requests to the other.
//
// It allows for hosting both a SOCKS5 proxy and an HTTP proxy on the
// same listener.
package proxymux
import (
"errors"
"fmt"
"io"
"net"
"sync"
"time"
)
// SplitSOCKSAndHTTP accepts connections on ln and passes connections
// through to either socksListener or httpListener, depending the
// first byte sent by the client.
func SplitSOCKSAndHTTP(ln net.Listener) (socksListener, httpListener net.Listener) {
sl := &listener{
addr: ln.Addr(),
c: make(chan net.Conn),
closed: make(chan struct{}),
func newMuxListener(listener net.Listener, deleteFunc func()) *muxListener {
l := &muxListener{
base: listener,
acceptChan: make(chan net.Conn),
closeChan: make(chan struct{}),
deleteFunc: deleteFunc,
}
hl := &listener{
addr: ln.Addr(),
c: make(chan net.Conn),
closed: make(chan struct{}),
}
go splitSOCKSAndHTTPListener(ln, sl, hl)
return sl, hl
go l.acceptLoop()
go l.mainLoop()
return l
}
func splitSOCKSAndHTTPListener(ln net.Listener, sl, hl *listener) {
type muxListener struct {
lock sync.Mutex
base net.Listener
acceptErr error
acceptChan chan net.Conn
closeChan chan struct{}
socksListener *subListener
httpListener *subListener
deleteFunc func()
}
func (l *muxListener) acceptLoop() {
defer close(l.acceptChan)
for {
conn, err := ln.Accept()
conn, err := l.base.Accept()
if err != nil {
sl.Close()
hl.Close()
l.lock.Lock()
l.acceptErr = err
l.lock.Unlock()
return
}
go routeConn(conn, sl, hl)
select {
case <-l.closeChan:
return
case l.acceptChan <- conn:
}
}
}
func routeConn(c net.Conn, socksListener, httpListener *listener) {
if err := c.SetReadDeadline(time.Now().Add(15 * time.Second)); err != nil {
c.Close()
return
}
func (l *muxListener) mainLoop() {
defer func() {
l.deleteFunc()
l.base.Close()
close(l.closeChan)
l.lock.Lock()
defer l.lock.Unlock()
if sl := l.httpListener; sl != nil {
close(sl.acceptChan)
l.httpListener = nil
}
if sl := l.socksListener; sl != nil {
close(sl.acceptChan)
l.socksListener = nil
}
}()
for {
var socksCloseChan, httpCloseChan chan struct{}
if l.httpListener != nil {
httpCloseChan = l.httpListener.closeChan
}
if l.socksListener != nil {
socksCloseChan = l.socksListener.closeChan
}
select {
case <-l.closeChan:
return
case conn, ok := <-l.acceptChan:
if !ok {
return
}
go l.dispatch(conn)
case <-socksCloseChan:
l.lock.Lock()
l.socksListener = nil
l.lock.Unlock()
if l.checkIdle() {
return
}
case <-httpCloseChan:
l.lock.Lock()
l.httpListener = nil
l.lock.Unlock()
if l.checkIdle() {
return
}
}
}
}
func (l *muxListener) dispatch(conn net.Conn) {
var b [1]byte
if _, err := io.ReadFull(c, b[:]); err != nil {
c.Close()
if _, err := io.ReadFull(conn, b[:]); err != nil {
conn.Close()
return
}
if err := c.SetReadDeadline(time.Time{}); err != nil {
c.Close()
return
}
conn := &connWithOneByte{
Conn: c,
b: b[0],
}
// First byte of a SOCKS5 session is a version byte set to 5.
var ln *listener
l.lock.Lock()
var target *subListener
if b[0] == 5 {
ln = socksListener
target = l.socksListener
} else {
ln = httpListener
target = l.httpListener
}
l.lock.Unlock()
if target == nil {
conn.Close()
return
}
wconn := &connWithOneByte{Conn: conn, b: b[0]}
select {
case ln.c <- conn:
case <-ln.closed:
c.Close()
case <-target.closeChan:
case target.acceptChan <- wconn:
}
}
type listener struct {
addr net.Addr
c chan net.Conn
mu sync.Mutex // serializes close() on closed. It's okay to receive on closed without locking.
closed chan struct{}
func (l *muxListener) checkIdle() bool {
l.lock.Lock()
defer l.lock.Unlock()
return l.httpListener == nil && l.socksListener == nil
}
func (ln *listener) Accept() (net.Conn, error) {
// Once closed, reliably stay closed, don't race with attempts at
// further connections.
func (l *muxListener) getAndClearAcceptError() error {
l.lock.Lock()
defer l.lock.Unlock()
if l.acceptErr == nil {
return nil
}
err := l.acceptErr
l.acceptErr = nil
return err
}
func (l *muxListener) ListenHTTP() (net.Listener, error) {
l.lock.Lock()
defer l.lock.Unlock()
if l.httpListener != nil {
return nil, OpErr{
Addr: l.base.Addr(),
Protocol: "http",
Op: "bind-protocol",
Err: ErrProtocolInUse,
}
}
select {
case <-ln.closed:
case <-l.closeChan:
return nil, net.ErrClosed
default:
}
sl := newSubListener(l.getAndClearAcceptError, l.base.Addr)
l.httpListener = sl
return sl, nil
}
func (l *muxListener) ListenSOCKS() (net.Listener, error) {
l.lock.Lock()
defer l.lock.Unlock()
if l.socksListener != nil {
return nil, OpErr{
Addr: l.base.Addr(),
Protocol: "socks",
Op: "bind-protocol",
Err: ErrProtocolInUse,
}
}
select {
case ret := <-ln.c:
return ret, nil
case <-ln.closed:
case <-l.closeChan:
return nil, net.ErrClosed
default:
}
sl := newSubListener(l.getAndClearAcceptError, l.base.Addr)
l.socksListener = sl
return sl, nil
}
func newSubListener(acceptErrorFunc func() error, addrFunc func() net.Addr) *subListener {
return &subListener{
acceptChan: make(chan net.Conn),
acceptErrorFunc: acceptErrorFunc,
closeChan: make(chan struct{}),
addrFunc: addrFunc,
}
}
func (ln *listener) Close() error {
ln.mu.Lock()
defer ln.mu.Unlock()
type subListener struct {
// receive connections or closure from upstream
acceptChan chan net.Conn
// get an error of Accept() from upstream
acceptErrorFunc func() error
// notify upstream that we are closed
closeChan chan struct{}
// Listener.Addr() implementation of base listener
addrFunc func() net.Addr
}
func (l *subListener) Accept() (net.Conn, error) {
select {
case <-ln.closed:
// Already closed
default:
close(ln.closed)
case <-l.closeChan:
// closed by ourselves
return nil, net.ErrClosed
case conn, ok := <-l.acceptChan:
if !ok {
// closed by upstream
if acceptErr := l.acceptErrorFunc(); acceptErr != nil {
return nil, acceptErr
}
return nil, net.ErrClosed
}
return conn, nil
}
}
func (l *subListener) Addr() net.Addr {
return l.addrFunc()
}
// Close implements net.Listener.Close.
// Upstream should use close(l.acceptChan) instead.
func (l *subListener) Close() error {
close(l.closeChan)
return nil
}
func (ln *listener) Addr() net.Addr {
return ln.addr
}
// connWithOneByte is a net.Conn that returns b for the first read
// request, then forwards everything else to Conn.
type connWithOneByte struct {
@ -139,3 +272,20 @@ func (c *connWithOneByte) Read(bs []byte) (int, error) {
bs[0] = c.b
return 1, nil
}
type OpErr struct {
Addr net.Addr
Protocol string
Op string
Err error
}
func (m OpErr) Error() string {
return fmt.Sprintf("mux-listen: %s[%s]: %s: %v", m.Addr, m.Protocol, m.Op, m.Err)
}
func (m OpErr) Unwrap() error {
return m.Err
}
var ErrProtocolInUse = errors.New("protocol already in use")