package proxymux

import (
	"bytes"
	"net"
	"sync"
	"testing"
	"time"

	"github.com/apernet/hysteria/app/v2/internal/proxymux/internal/mocks"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
)

//go:generate mockery

// testMockListener builds a mocked net.Listener driven by connChan.
//
// Accept blocks until either a connection is received from connChan (which is
// then returned) or Close has been called (in which case net.ErrClosed is
// returned). Accept panics if connChan itself is closed, since the tests are
// expected to keep it open. Close may be called more than once; only the
// first call closes the internal "closed" signal channel.
func testMockListener(t *testing.T, connChan <-chan net.Conn) net.Listener {
	closedChan := make(chan struct{})
	mockListener := mocks.NewMockListener(t)
	mockListener.EXPECT().Accept().RunAndReturn(func() (net.Conn, error) {
		select {
		case <-closedChan:
			return nil, net.ErrClosed
		case conn, ok := <-connChan:
			if !ok {
				panic("unexpected closed channel (connChan)")
			}
			return conn, nil
		}
	})
	mockListener.EXPECT().Close().RunAndReturn(func() error {
		select {
		case <-closedChan:
			// Already closed; closing the channel again would panic.
		default:
			close(closedChan)
		}
		return nil
	})
	return mockListener
}

// testMockConn builds a mocked net.Conn whose Read serves the bytes in b.
//
// Once Close has been called, every subsequent Read returns net.ErrClosed.
// Close always returns nil.
func testMockConn(t *testing.T, b []byte) net.Conn {
	buf := bytes.NewReader(b)
	isClosed := false
	mockConn := mocks.NewMockConn(t)
	mockConn.EXPECT().Read(mock.Anything).RunAndReturn(func(b []byte) (int, error) {
		if isClosed {
			return 0, net.ErrClosed
		}
		return buf.Read(b)
	})
	mockConn.EXPECT().Close().RunAndReturn(func() error {
		isClosed = true
		return nil
	})
	return mockConn
}

// TestMuxHTTP feeds the mux a connection that starts with an HTTP CONNECT
// request and checks that it is delivered by the HTTP listener, while the
// SOCKS listener's Accept returns net.ErrClosed once it is closed.
func TestMuxHTTP(t *testing.T) {
	connChan := make(chan net.Conn)
	mockListener := testMockListener(t, connChan)
	mockConn := testMockConn(t, []byte("CONNECT example.com:443 HTTP/1.1\r\n\r\n"))

	mux := newMuxListener(mockListener, func() {})
	hl, err := mux.ListenHTTP()
	if !assert.NoError(t, err) {
		return
	}
	sl, err := mux.ListenSOCKS()
	if !assert.NoError(t, err) {
		return
	}

	connChan <- mockConn

	var socksConn, httpConn net.Conn
	var socksErr, httpErr error

	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		socksConn, socksErr = sl.Accept()
	}()
	go func() {
		defer wg.Done()
		httpConn, httpErr = hl.Accept()
	}()

	// NOTE(review): presumably gives the mux time to hand the connection to
	// the matching listener before both listeners are closed - confirm.
	time.Sleep(time.Second)

	sl.Close()
	hl.Close()

	wg.Wait()

	assert.Nil(t, socksConn)
	assert.ErrorIs(t, socksErr, net.ErrClosed)
	// assert.NotNil does not abort the test, so only close the connection
	// when it was actually delivered; otherwise Close would panic on nil.
	if assert.NotNil(t, httpConn) {
		httpConn.Close()
	}
	assert.NoError(t, httpErr)

	// Wait for the muxListener to be released.
	<-mux.acceptChan
}

// TestMuxSOCKS feeds the mux a connection that starts with a SOCKS5 greeting
// and checks that it is delivered by the SOCKS listener, while the HTTP
// listener's Accept returns net.ErrClosed once it is closed.
func TestMuxSOCKS(t *testing.T) {
	connChan := make(chan net.Conn)
	mockListener := testMockListener(t, connChan)
	// SOCKS5 method-selection message: version 5, two methods offered
	// (0x00 NOAUTH, 0x01 GSSAPI).
	mockConn := testMockConn(t, []byte{0x05, 0x02, 0x00, 0x01})

	mux := newMuxListener(mockListener, func() {})
	hl, err := mux.ListenHTTP()
	if !assert.NoError(t, err) {
		return
	}
	sl, err := mux.ListenSOCKS()
	if !assert.NoError(t, err) {
		return
	}

	connChan <- mockConn

	var socksConn, httpConn net.Conn
	var socksErr, httpErr error

	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		socksConn, socksErr = sl.Accept()
	}()
	go func() {
		defer wg.Done()
		httpConn, httpErr = hl.Accept()
	}()

	// NOTE(review): presumably gives the mux time to hand the connection to
	// the matching listener before both listeners are closed - confirm.
	time.Sleep(time.Second)

	sl.Close()
	hl.Close()

	wg.Wait()

	// assert.NotNil does not abort the test, so only close the connection
	// when it was actually delivered; otherwise Close would panic on nil.
	if assert.NotNil(t, socksConn) {
		socksConn.Close()
	}
	assert.NoError(t, socksErr)
	assert.Nil(t, httpConn)
	assert.ErrorIs(t, httpErr, net.ErrClosed)

	// Wait for the muxListener to be released.
	<-mux.acceptChan
}