package proxymux import ( "errors" "fmt" "io" "net" "sync" ) func newMuxListener(listener net.Listener, deleteFunc func()) *muxListener { l := &muxListener{ base: listener, acceptChan: make(chan net.Conn), closeChan: make(chan struct{}), deleteFunc: deleteFunc, } go l.acceptLoop() go l.mainLoop() return l } type muxListener struct { lock sync.Mutex base net.Listener acceptErr error acceptChan chan net.Conn closeChan chan struct{} socksListener *subListener httpListener *subListener deleteFunc func() } func (l *muxListener) acceptLoop() { defer close(l.acceptChan) for { conn, err := l.base.Accept() if err != nil { l.lock.Lock() l.acceptErr = err l.lock.Unlock() return } select { case <-l.closeChan: return case l.acceptChan <- conn: } } } func (l *muxListener) mainLoop() { defer func() { l.deleteFunc() l.base.Close() close(l.closeChan) l.lock.Lock() defer l.lock.Unlock() if sl := l.httpListener; sl != nil { close(sl.acceptChan) l.httpListener = nil } if sl := l.socksListener; sl != nil { close(sl.acceptChan) l.socksListener = nil } }() for { var socksCloseChan, httpCloseChan chan struct{} if l.httpListener != nil { httpCloseChan = l.httpListener.closeChan } if l.socksListener != nil { socksCloseChan = l.socksListener.closeChan } select { case <-l.closeChan: return case conn, ok := <-l.acceptChan: if !ok { return } go l.dispatch(conn) case <-socksCloseChan: l.lock.Lock() if socksCloseChan == l.socksListener.closeChan { // not replaced by another ListenSOCKS() l.socksListener = nil } l.lock.Unlock() if l.checkIdle() { return } case <-httpCloseChan: l.lock.Lock() if httpCloseChan == l.httpListener.closeChan { // not replaced by another ListenHTTP() l.httpListener = nil } l.lock.Unlock() if l.checkIdle() { return } } } } func (l *muxListener) dispatch(conn net.Conn) { var b [1]byte if _, err := io.ReadFull(conn, b[:]); err != nil { conn.Close() return } l.lock.Lock() var target *subListener if b[0] == 5 { target = l.socksListener } else { target = l.httpListener } l.lock.Unlock() if target == nil { conn.Close() return } wconn := &connWithOneByte{Conn: conn, b: b[0]} select { case <-target.closeChan: case target.acceptChan <- wconn: } } func (l *muxListener) checkIdle() bool { l.lock.Lock() defer l.lock.Unlock() return l.httpListener == nil && l.socksListener == nil } func (l *muxListener) getAndClearAcceptError() error { l.lock.Lock() defer l.lock.Unlock() if l.acceptErr == nil { return nil } err := l.acceptErr l.acceptErr = nil return err } func (l *muxListener) ListenHTTP() (net.Listener, error) { l.lock.Lock() defer l.lock.Unlock() if l.httpListener != nil { subListenerPendingClosed := false select { case <-l.httpListener.closeChan: subListenerPendingClosed = true default: } if !subListenerPendingClosed { return nil, OpErr{ Addr: l.base.Addr(), Protocol: "http", Op: "bind-protocol", Err: ErrProtocolInUse, } } l.httpListener = nil } select { case <-l.closeChan: return nil, net.ErrClosed default: } sl := newSubListener(l.getAndClearAcceptError, l.base.Addr) l.httpListener = sl return sl, nil } func (l *muxListener) ListenSOCKS() (net.Listener, error) { l.lock.Lock() defer l.lock.Unlock() if l.socksListener != nil { subListenerPendingClosed := false select { case <-l.socksListener.closeChan: subListenerPendingClosed = true default: } if !subListenerPendingClosed { return nil, OpErr{ Addr: l.base.Addr(), Protocol: "socks", Op: "bind-protocol", Err: ErrProtocolInUse, } } l.socksListener = nil } select { case <-l.closeChan: return nil, net.ErrClosed default: } sl := newSubListener(l.getAndClearAcceptError, l.base.Addr) l.socksListener = sl return sl, nil } func newSubListener(acceptErrorFunc func() error, addrFunc func() net.Addr) *subListener { return &subListener{ acceptChan: make(chan net.Conn), acceptErrorFunc: acceptErrorFunc, closeChan: make(chan struct{}), addrFunc: addrFunc, } } type subListener struct { // receive connections or closure from upstream acceptChan chan net.Conn // get an error of Accept() from upstream acceptErrorFunc func() error // notify upstream that we are closed closeChan chan struct{} // Listener.Addr() implementation of base listener addrFunc func() net.Addr } func (l *subListener) Accept() (net.Conn, error) { select { case <-l.closeChan: // closed by ourselves return nil, net.ErrClosed case conn, ok := <-l.acceptChan: if !ok { // closed by upstream if acceptErr := l.acceptErrorFunc(); acceptErr != nil { return nil, acceptErr } return nil, net.ErrClosed } return conn, nil } } func (l *subListener) Addr() net.Addr { return l.addrFunc() } // Close implements net.Listener.Close. // Upstream should use close(l.acceptChan) instead. func (l *subListener) Close() error { select { case <-l.closeChan: return nil default: } close(l.closeChan) return nil } // connWithOneByte is a net.Conn that returns b for the first read // request, then forwards everything else to Conn. type connWithOneByte struct { net.Conn b byte bRead bool } func (c *connWithOneByte) Read(bs []byte) (int, error) { if c.bRead { return c.Conn.Read(bs) } if len(bs) == 0 { return 0, nil } c.bRead = true bs[0] = c.b return 1, nil } type OpErr struct { Addr net.Addr Protocol string Op string Err error } func (m OpErr) Error() string { return fmt.Sprintf("mux-listen: %s[%s]: %s: %v", m.Addr, m.Protocol, m.Op, m.Err) } func (m OpErr) Unwrap() error { return m.Err } var ErrProtocolInUse = errors.New("protocol already in use")