diff --git a/app/internal/proxymux/manager_test.go b/app/internal/proxymux/manager_test.go index 1d5a7af..c776058 100644 --- a/app/internal/proxymux/manager_test.go +++ b/app/internal/proxymux/manager_test.go @@ -31,9 +31,6 @@ func TestListenSOCKS(t *testing.T) { } sl.Close() - // Wait for muxListener.socksListener released - time.Sleep(time.Second) - sl, err = ListenSOCKS(address) if !assert.NoError(t, err) { return @@ -63,9 +60,6 @@ func TestListenHTTP(t *testing.T) { } hl.Close() - // Wait for muxListener.socksListener released - time.Sleep(time.Second) - hl, err = ListenHTTP(address) if !assert.NoError(t, err) { return diff --git a/app/internal/proxymux/mux_test.go b/app/internal/proxymux/mux_test.go index 0e24f95..7a42b83 100644 --- a/app/internal/proxymux/mux_test.go +++ b/app/internal/proxymux/mux_test.go @@ -1,7 +1,9 @@ package proxymux import ( + "bytes" "net" + "sync" "testing" "time" @@ -13,25 +15,52 @@ import ( //go:generate mockery -func testMockListener(t *testing.T, firstByte byte) net.Listener { - mockConn := mocks.NewMockConn(t) - mockConn.EXPECT().Read(mock.Anything).RunAndReturn(func(b []byte) (int, error) { - b[0] = firstByte - return 1, nil - }) - mockConn.EXPECT().Close().Return(nil) +func testMockListener(t *testing.T, connChan <-chan net.Conn) net.Listener { + closedChan := make(chan struct{}) mockListener := mocks.NewMockListener(t) mockListener.EXPECT().Accept().RunAndReturn(func() (net.Conn, error) { - // Wait for all listener set up - time.Sleep(200 * time.Millisecond) - return mockConn, nil + select { + case <-closedChan: + return nil, net.ErrClosed + case conn, ok := <-connChan: + if !ok { + panic("unexpected closed channel (connChan)") + } + return conn, nil + } + }) + mockListener.EXPECT().Close().RunAndReturn(func() error { + select { + case <-closedChan: + default: + close(closedChan) + } + return nil }) - mockListener.EXPECT().Close().Return(nil) return mockListener } +func testMockConn(t *testing.T, b []byte) net.Conn { + buf := bytes.NewReader(b) + isClosed := false + mockConn := mocks.NewMockConn(t) + mockConn.EXPECT().Read(mock.Anything).RunAndReturn(func(b []byte) (int, error) { + if isClosed { + return 0, net.ErrClosed + } + return buf.Read(b) + }) + mockConn.EXPECT().Close().RunAndReturn(func() error { + isClosed = true + return nil + }) + return mockConn +} + func TestMuxHTTP(t *testing.T) { - mockListener := testMockListener(t, 'C') + connChan := make(chan net.Conn) + mockListener := testMockListener(t, connChan) + mockConn := testMockConn(t, []byte("CONNECT example.com:443 HTTP/1.1\r\n\r\n")) mux := newMuxListener(mockListener, func() {}) hl, err := mux.ListenHTTP() @@ -43,22 +72,28 @@ func TestMuxHTTP(t *testing.T) { return } + connChan <- mockConn + var socksConn, httpConn net.Conn var socksErr, httpErr error + var wg sync.WaitGroup + wg.Add(2) go func() { socksConn, socksErr = sl.Accept() + wg.Done() }() - go func() { httpConn, httpErr = hl.Accept() + wg.Done() }() - time.Sleep(1 * time.Second) + time.Sleep(time.Second) + sl.Close() hl.Close() - // Wait for unmatched handler error - time.Sleep(1 * time.Second) + + wg.Wait() assert.Nil(t, socksConn) assert.ErrorIs(t, socksErr, net.ErrClosed) @@ -67,11 +102,13 @@ func TestMuxHTTP(t *testing.T) { assert.NoError(t, httpErr) // Wait for muxListener released - time.Sleep(time.Second) + <-mux.acceptChan } func TestMuxSOCKS(t *testing.T) { - mockListener := testMockListener(t, '\x05') + connChan := make(chan net.Conn) + mockListener := testMockListener(t, connChan) + mockConn := testMockConn(t, []byte{0x05, 0x02, 0x00, 0x01}) // SOCKS5 Connect Request: NOAUTH+GSSAPI mux := newMuxListener(mockListener, func() {}) hl, err := mux.ListenHTTP() @@ -83,22 +120,28 @@ func TestMuxSOCKS(t *testing.T) { return } + connChan <- mockConn + var socksConn, httpConn net.Conn var socksErr, httpErr error + var wg sync.WaitGroup + wg.Add(2) go func() { socksConn, socksErr = sl.Accept() + wg.Done() }() - go func() { httpConn, httpErr = hl.Accept() + wg.Done() }() - time.Sleep(1 * time.Second) + time.Sleep(time.Second) + sl.Close() hl.Close() - // Wait for unmatched handler error - time.Sleep(1 * time.Second) + + wg.Wait() assert.NotNil(t, socksConn) socksConn.Close() @@ -107,5 +150,5 @@ func TestMuxSOCKS(t *testing.T) { assert.ErrorIs(t, httpErr, net.ErrClosed) // Wait for muxListener released - time.Sleep(time.Second) + <-mux.acceptChan }