diff --git a/app/cmd/ping.go b/app/cmd/ping.go index 3080551..ccaf870 100644 --- a/app/cmd/ping.go +++ b/app/cmd/ping.go @@ -50,7 +50,7 @@ func runPing(cmd *cobra.Command, args []string) { logger.Info("connecting", zap.String("address", addr)) start := time.Now() - conn, err := c.DialTCP(addr) + conn, err := c.TCP(addr) if err != nil { logger.Fatal("failed to connect", zap.Error(err), zap.String("time", time.Since(start).String())) } diff --git a/app/internal/forwarding/tcp.go b/app/internal/forwarding/tcp.go index da21bdb..7d22d33 100644 --- a/app/internal/forwarding/tcp.go +++ b/app/internal/forwarding/tcp.go @@ -41,7 +41,7 @@ func (t *TCPTunnel) handle(conn net.Conn) { } }() - rc, err := t.HyClient.DialTCP(t.Remote) + rc, err := t.HyClient.TCP(t.Remote) if err != nil { closeErr = err return diff --git a/app/internal/forwarding/udp.go b/app/internal/forwarding/udp.go index 2bf46f4..93d73ec 100644 --- a/app/internal/forwarding/udp.go +++ b/app/internal/forwarding/udp.go @@ -118,7 +118,7 @@ func (t *UDPTunnel) handle(l net.PacketConn, sm *sessionManager, addr net.Addr, if t.EventLogger != nil { t.EventLogger.Connect(addr) } - hyConn, err := t.HyClient.ListenUDP() + hyConn, err := t.HyClient.UDP() if err != nil { if t.EventLogger != nil { t.EventLogger.Error(addr, err) diff --git a/app/internal/http/server.go b/app/internal/http/server.go index ba56312..4f11d14 100644 --- a/app/internal/http/server.go +++ b/app/internal/http/server.go @@ -156,7 +156,7 @@ func (s *Server) handleConnect(conn net.Conn, req *http.Request) { }() // Dial - rConn, err := s.HyClient.DialTCP(reqAddr) + rConn, err := s.HyClient.TCP(reqAddr) if err != nil { _ = sendSimpleResponse(conn, req, http.StatusBadGateway) closeErr = err @@ -233,7 +233,7 @@ func (s *Server) initHTTPClient() { Transport: &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { // HyClient doesn't support context for now - return s.HyClient.DialTCP(addr) + return s.HyClient.TCP(addr) }, }, CheckRedirect: func(req *http.Request, via []*http.Request) error { diff --git a/app/internal/http/server_test.go b/app/internal/http/server_test.go index 878362f..fbd24c5 100644 --- a/app/internal/http/server_test.go +++ b/app/internal/http/server_test.go @@ -17,11 +17,11 @@ const ( type mockHyClient struct{} -func (c *mockHyClient) DialTCP(addr string) (net.Conn, error) { +func (c *mockHyClient) TCP(addr string) (net.Conn, error) { return net.Dial("tcp", addr) } -func (c *mockHyClient) ListenUDP() (client.HyUDPConn, error) { +func (c *mockHyClient) UDP() (client.HyUDPConn, error) { // Not implemented return nil, errors.New("not implemented") } diff --git a/app/internal/socks5/server.go b/app/internal/socks5/server.go index 9c6ce24..84b58ed 100644 --- a/app/internal/socks5/server.go +++ b/app/internal/socks5/server.go @@ -135,7 +135,7 @@ func (s *Server) handleTCP(conn net.Conn, req *socks5.Request) { }() // Dial - rConn, err := s.HyClient.DialTCP(addr) + rConn, err := s.HyClient.TCP(addr) if err != nil { _ = sendSimpleReply(conn, socks5.RepHostUnreachable) closeErr = err @@ -196,7 +196,7 @@ func (s *Server) handleUDP(conn net.Conn, req *socks5.Request) { defer udpConn.Close() // HyClient UDP session - hyUDP, err := s.HyClient.ListenUDP() + hyUDP, err := s.HyClient.UDP() if err != nil { _ = sendSimpleReply(conn, socks5.RepServerFailure) closeErr = err diff --git a/app/internal/utils_test/mock.go b/app/internal/utils_test/mock.go index 4e04d85..04adad6 100644 --- a/app/internal/utils_test/mock.go +++ b/app/internal/utils_test/mock.go @@ -10,13 +10,13 @@ import ( type MockEchoHyClient struct{} -func (c *MockEchoHyClient) DialTCP(addr string) (net.Conn, error) { +func (c *MockEchoHyClient) TCP(addr string) (net.Conn, error) { return &mockEchoTCPConn{ BufChan: make(chan []byte, 10), }, nil } -func (c *MockEchoHyClient) ListenUDP() (client.HyUDPConn, error) { +func (c *MockEchoHyClient) UDP() (client.HyUDPConn, error) { return &mockEchoUDPConn{ BufChan: make(chan mockEchoUDPPacket, 10), }, nil diff --git a/core/client/client.go b/core/client/client.go index 3e9e259..6602b87 100644 --- a/core/client/client.go +++ b/core/client/client.go @@ -23,8 +23,8 @@ const ( ) type Client interface { - DialTCP(addr string) (net.Conn, error) - ListenUDP() (HyUDPConn, error) + TCP(addr string) (net.Conn, error) + UDP() (HyUDPConn, error) Close() error } @@ -146,7 +146,7 @@ func (c *clientImpl) openStream() (quic.Stream, error) { return &utils.QStream{Stream: stream}, nil } -func (c *clientImpl) DialTCP(addr string) (net.Conn, error) { +func (c *clientImpl) TCP(addr string) (net.Conn, error) { stream, err := c.openStream() if err != nil { if netErr, ok := err.(net.Error); ok && !netErr.Temporary() { @@ -190,7 +190,7 @@ func (c *clientImpl) DialTCP(addr string) (net.Conn, error) { }, nil } -func (c *clientImpl) ListenUDP() (HyUDPConn, error) { +func (c *clientImpl) UDP() (HyUDPConn, error) { if c.udpSM == nil { return nil, coreErrs.DialError{Message: "UDP not enabled"} } diff --git a/core/internal/integration_tests/.mockery.yaml b/core/internal/integration_tests/.mockery.yaml index 42ff698..b1da36f 100644 --- a/core/internal/integration_tests/.mockery.yaml +++ b/core/internal/integration_tests/.mockery.yaml @@ -2,8 +2,19 @@ with-expecter: true dir: mocks outpkg: mocks packages: + net: + interfaces: + Conn: + config: + mockname: MockConn github.com/apernet/hysteria/core/server: interfaces: + Outbound: + config: + mockname: MockOutbound + UDPConn: + config: + mockname: MockUDPConn Authenticator: config: mockname: MockAuthenticator diff --git a/core/internal/integration_tests/close_test.go b/core/internal/integration_tests/close_test.go index 1114c23..531deb2 100644 --- a/core/internal/integration_tests/close_test.go +++ b/core/internal/integration_tests/close_test.go @@ -1,10 +1,9 @@ package integration_tests import ( - "crypto/rand" "io" - "net" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -15,17 +14,18 @@ import ( ) // TestClientServerTCPClose tests whether the client/server propagates the close of a connection correctly. -// In other words, closing one of the client/remote connections should cause the other to close as well. +// Closing one side of the connection should close the other side as well. func TestClientServerTCPClose(t *testing.T) { // Create server - udpAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 14514} - udpConn, err := net.ListenUDP("udp", udpAddr) + udpConn, udpAddr, err := serverConn() assert.NoError(t, err) + serverOb := mocks.NewMockOutbound(t) auth := mocks.NewMockAuthenticator(t) auth.EXPECT().Authenticate(mock.Anything, mock.Anything, mock.Anything).Return(true, "nobody") s, err := server.NewServer(&server.Config{ TLSConfig: serverTLSConfig(), Conn: udpConn, + Outbound: serverOb, Authenticator: auth, }) assert.NoError(t, err) @@ -40,123 +40,137 @@ func TestClientServerTCPClose(t *testing.T) { assert.NoError(t, err) defer c.Close() - t.Run("Close local", func(t *testing.T) { - // TCP sink server - sinkAddr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 33344} - sinkListener, err := net.ListenTCP("tcp", sinkAddr) - assert.NoError(t, err) - sinkCh := make(chan sinkEvent, 1) - sinkServer := &tcpSinkServer{ - Listener: sinkListener, - Ch: sinkCh, - } - defer sinkServer.Close() - go sinkServer.Serve() + addr := "hi-and-goodbye:2333" - // Generate some random data - sData := make([]byte, 1024000) - _, err = rand.Read(sData) - assert.NoError(t, err) - - // Dial and send data to TCP sink server - conn, err := c.DialTCP(sinkAddr.String()) - assert.NoError(t, err) - _, err = conn.Write(sData) - assert.NoError(t, err) - - // Close the connection - // This should cause the sink server to send an event to the channel - _ = conn.Close() - event := <-sinkCh - assert.NoError(t, event.Err) - assert.Equal(t, sData, event.Data) + // Test close from client side: + // Client creates a connection, writes something, then closes it. + // Server outbound connection should write the same thing, then close. + sobConn := mocks.NewMockConn(t) + sobConnCh := make(chan struct{}) // For close signal only + sobConn.EXPECT().Read(mock.Anything).RunAndReturn(func(bs []byte) (int, error) { + <-sobConnCh + return 0, io.EOF }) - - t.Run("Close remote", func(t *testing.T) { - // Generate some random data - sData := make([]byte, 1024000) - _, err = rand.Read(sData) - assert.NoError(t, err) - - // TCP sender server - senderAddr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 33345} - senderListener, err := net.ListenTCP("tcp", senderAddr) - assert.NoError(t, err) - - senderServer := &tcpSenderServer{ - Listener: senderListener, - Data: sData, - } - defer senderServer.Close() - go senderServer.Serve() - - // Dial and read data from TCP sender server - conn, err := c.DialTCP(senderAddr.String()) - assert.NoError(t, err) - defer conn.Close() - rData, err := io.ReadAll(conn) - assert.NoError(t, err) - assert.Equal(t, sData, rData) + sobConn.EXPECT().Write([]byte("happy")).Return(5, nil) + sobConn.EXPECT().Close().RunAndReturn(func() error { + close(sobConnCh) + return nil }) + serverOb.EXPECT().TCP(addr).Return(sobConn, nil).Once() + conn, err := c.TCP(addr) + assert.NoError(t, err) + _, err = conn.Write([]byte("happy")) + assert.NoError(t, err) + err = conn.Close() + assert.NoError(t, err) + time.Sleep(1 * time.Second) + mock.AssertExpectationsForObjects(t, sobConn, serverOb) + + // Test close from server side: + // Client creates a connection. + // Server outbound connection reads something, then closes. + // Client connection should read the same thing, then close. + sobConn = mocks.NewMockConn(t) + sobConnCh2 := make(chan []byte, 1) + sobConn.EXPECT().Read(mock.Anything).RunAndReturn(func(bs []byte) (int, error) { + d := <-sobConnCh2 + if d == nil { + return 0, io.EOF + } else { + return copy(bs, d), nil + } + }) + sobConn.EXPECT().Close().Return(nil) + serverOb.EXPECT().TCP(addr).Return(sobConn, nil).Once() + conn, err = c.TCP(addr) + assert.NoError(t, err) + sobConnCh2 <- []byte("happy") + close(sobConnCh2) + bs, err := io.ReadAll(conn) + assert.NoError(t, err) + assert.Equal(t, "happy", string(bs)) + _ = conn.Close() } -// TestClientServerUDPClose is the same as TestClientServerTCPClose, but for UDP. -// Checking for UDP close is a bit tricky, so we will rely on the server event for now. -func TestClientServerUDPClose(t *testing.T) { - urCh := make(chan udpRequestEvent, 1) - ueCh := make(chan udpErrorEvent, 1) - +// TestServerUDPIdleTimeout tests whether the server's UDP idle timeout works correctly. +func TestServerUDPIdleTimeout(t *testing.T) { // Create server - udpAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 14514} - udpConn, err := net.ListenUDP("udp", udpAddr) - if err != nil { - t.Fatal("error creating server:", err) - } + udpConn, udpAddr, err := serverConn() + assert.NoError(t, err) + serverOb := mocks.NewMockOutbound(t) + auth := mocks.NewMockAuthenticator(t) + auth.EXPECT().Authenticate(mock.Anything, mock.Anything, mock.Anything).Return(true, "nobody") + eventLogger := mocks.NewMockEventLogger(t) + eventLogger.EXPECT().Connect(mock.Anything, "nobody", mock.Anything).Once() + eventLogger.EXPECT().Disconnect(mock.Anything, "nobody", mock.Anything).Maybe() // Depends on the timing, don't care s, err := server.NewServer(&server.Config{ - TLSConfig: serverTLSConfig(), - Conn: udpConn, - Authenticator: &pwAuthenticator{ - Password: "password", - ID: "nobody", - }, - EventLogger: &channelEventLogger{ - UDPRequestEventCh: urCh, - UDPErrorEventCh: ueCh, - }, + TLSConfig: serverTLSConfig(), + Conn: udpConn, + Outbound: serverOb, + UDPIdleTimeout: 2 * time.Second, + Authenticator: auth, + EventLogger: eventLogger, }) - if err != nil { - t.Fatal("error creating server:", err) - } + assert.NoError(t, err) defer s.Close() go s.Serve() // Create client c, err := client.NewClient(&client.Config{ ServerAddr: udpAddr, - Auth: "password", TLSConfig: client.TLSConfig{InsecureSkipVerify: true}, }) - if err != nil { - t.Fatal("error creating client:", err) - } + assert.NoError(t, err) defer c.Close() - // Listen UDP and close it, then check the server events - conn, err := c.ListenUDP() - if err != nil { - t.Fatal("error listening UDP:", err) - } - _ = conn.Close() + addr := "spy.x.family:2023" - reqEvent := <-urCh - if reqEvent.ID != "nobody" { - t.Fatal("incorrect ID in request event") + // On the client side, create a UDP session and send a packet every 1 second, + // 4 packets in total. The server should have one UDP session and receive all + // 4 packets. Then the UDP connection on the server side will receive a packet + // every 1 second, 4 packets in total. The client session should receive all + // 4 packets. Then the session will be idle for 3 seconds - should be enough + // to trigger the server's UDP idle timeout. + sobConn := mocks.NewMockUDPConn(t) + sobConnCh := make(chan []byte, 1) + sobConn.EXPECT().ReadFrom(mock.Anything).RunAndReturn(func(bs []byte) (int, string, error) { + d := <-sobConnCh + if d == nil { + return 0, "", io.EOF + } else { + return copy(bs, d), addr, nil + } + }) + sobConn.EXPECT().WriteTo([]byte("happy"), addr).Return(5, nil).Times(4) + serverOb.EXPECT().UDP(addr).Return(sobConn, nil).Once() + eventLogger.EXPECT().UDPRequest(mock.Anything, mock.Anything, uint32(1), addr).Once() + cu, err := c.UDP() + assert.NoError(t, err) + // Client sends 4 packets + for i := 0; i < 4; i++ { + err = cu.Send([]byte("happy"), addr) + assert.NoError(t, err) + time.Sleep(1 * time.Second) } - errEvent := <-ueCh - if errEvent.ID != "nobody" { - t.Fatal("incorrect ID in error event") - } - if errEvent.Err != nil { - t.Fatal("non-nil error received from server:", errEvent.Err) + // Client receives 4 packets + go func() { + for i := 0; i < 4; i++ { + sobConnCh <- []byte("sad") + time.Sleep(1 * time.Second) + } + }() + for i := 0; i < 4; i++ { + bs, rAddr, err := cu.Receive() + assert.NoError(t, err) + assert.Equal(t, "sad", string(bs)) + assert.Equal(t, addr, rAddr) } + // Now we wait for 3 seconds, the server should close the UDP session. + sobConn.EXPECT().Close().RunAndReturn(func() error { + close(sobConnCh) + return nil + }) + eventLogger.EXPECT().UDPError(mock.Anything, mock.Anything, uint32(1), nil).Once() + time.Sleep(3 * time.Second) + mock.AssertExpectationsForObjects(t, sobConn, serverOb, eventLogger) } diff --git a/core/internal/integration_tests/masq_test.go b/core/internal/integration_tests/masq_test.go index 333be94..fb12e8f 100644 --- a/core/internal/integration_tests/masq_test.go +++ b/core/internal/integration_tests/masq_test.go @@ -6,10 +6,13 @@ import ( "net" "net/http" "net/url" - "strings" "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/apernet/hysteria/core/internal/integration_tests/mocks" "github.com/apernet/hysteria/core/internal/protocol" "github.com/apernet/hysteria/core/server" @@ -22,22 +25,16 @@ import ( // confirm that the server does not expose itself to active probers. func TestServerMasquerade(t *testing.T) { // Create server - udpAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 14514} - udpConn, err := net.ListenUDP("udp", udpAddr) - if err != nil { - t.Fatal("error creating server:", err) - } + udpConn, udpAddr, err := serverConn() + assert.NoError(t, err) + auth := mocks.NewMockAuthenticator(t) + auth.EXPECT().Authenticate(mock.Anything, "", uint64(0)).Return(false, "").Once() s, err := server.NewServer(&server.Config{ - TLSConfig: serverTLSConfig(), - Conn: udpConn, - Authenticator: &pwAuthenticator{ - Password: "password", - ID: "nobody", - }, + TLSConfig: serverTLSConfig(), + Conn: udpConn, + Authenticator: auth, }) - if err != nil { - t.Fatal("error creating server:", err) - } + assert.NoError(t, err) defer s.Close() go s.Serve() @@ -71,39 +68,27 @@ func TestServerMasquerade(t *testing.T) { Header: make(http.Header), } resp, err := rt.RoundTrip(req) - if err != nil { - t.Fatal("error sending request:", err) - } - if resp.StatusCode != http.StatusNotFound { - t.Fatalf("expected status %d, got %d", http.StatusNotFound, resp.StatusCode) - } + assert.NoError(t, err) + assert.Equal(t, http.StatusNotFound, resp.StatusCode) for k := range resp.Header { - // Make sure no strange headers are sent - if strings.Contains(k, "Hysteria") { - t.Fatal("expected no Hysteria headers, got", k) - } + // Make sure no strange headers are sent by the server + assert.NotContains(t, k, "Hysteria") } buf := make([]byte, 1024) // We send a TCP request anyway, see if we get a response tcpStream, err := conn.OpenStream() - if err != nil { - t.Fatal("error opening stream:", err) - } + assert.NoError(t, err) defer tcpStream.Close() err = protocol.WriteTCPRequest(tcpStream, "www.google.com:443") - if err != nil { - t.Fatal("error sending request:", err) - } + assert.NoError(t, err) // We should receive nothing _ = tcpStream.SetReadDeadline(time.Now().Add(2 * time.Second)) n, err := tcpStream.Read(buf) - if n != 0 { - t.Fatal("expected no response, got", n) - } - if nErr, ok := err.(net.Error); !ok || !nErr.Timeout() { - t.Fatal("expected timeout, got", err) - } + assert.Equal(t, 0, n) + nErr, ok := err.(net.Error) + assert.True(t, ok) + assert.True(t, nErr.Timeout()) } diff --git a/core/internal/integration_tests/mocks/mock_Conn.go b/core/internal/integration_tests/mocks/mock_Conn.go new file mode 100644 index 0000000..6840332 --- /dev/null +++ b/core/internal/integration_tests/mocks/mock_Conn.go @@ -0,0 +1,395 @@ +// Code generated by mockery v2.32.0. DO NOT EDIT. + +package mocks + +import ( + net "net" + + mock "github.com/stretchr/testify/mock" + + time "time" +) + +// MockConn is an autogenerated mock type for the Conn type +type MockConn struct { + mock.Mock +} + +type MockConn_Expecter struct { + mock *mock.Mock +} + +func (_m *MockConn) EXPECT() *MockConn_Expecter { + return &MockConn_Expecter{mock: &_m.Mock} +} + +// Close provides a mock function with given fields: +func (_m *MockConn) Close() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockConn_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type MockConn_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *MockConn_Expecter) Close() *MockConn_Close_Call { + return &MockConn_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *MockConn_Close_Call) Run(run func()) *MockConn_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockConn_Close_Call) Return(_a0 error) *MockConn_Close_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockConn_Close_Call) RunAndReturn(run func() error) *MockConn_Close_Call { + _c.Call.Return(run) + return _c +} + +// LocalAddr provides a mock function with given fields: +func (_m *MockConn) LocalAddr() net.Addr { + ret := _m.Called() + + var r0 net.Addr + if rf, ok := ret.Get(0).(func() net.Addr); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(net.Addr) + } + } + + return r0 +} + +// MockConn_LocalAddr_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LocalAddr' +type MockConn_LocalAddr_Call struct { + *mock.Call +} + +// LocalAddr is a helper method to define mock.On call +func (_e *MockConn_Expecter) LocalAddr() *MockConn_LocalAddr_Call { + return &MockConn_LocalAddr_Call{Call: _e.mock.On("LocalAddr")} +} + +func (_c *MockConn_LocalAddr_Call) Run(run func()) *MockConn_LocalAddr_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockConn_LocalAddr_Call) Return(_a0 net.Addr) *MockConn_LocalAddr_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockConn_LocalAddr_Call) RunAndReturn(run func() net.Addr) *MockConn_LocalAddr_Call { + _c.Call.Return(run) + return _c +} + +// Read provides a mock function with given fields: b +func (_m *MockConn) Read(b []byte) (int, error) { + ret := _m.Called(b) + + var r0 int + var r1 error + if rf, ok := ret.Get(0).(func([]byte) (int, error)); ok { + return rf(b) + } + if rf, ok := ret.Get(0).(func([]byte) int); ok { + r0 = rf(b) + } else { + r0 = ret.Get(0).(int) + } + + if rf, ok := ret.Get(1).(func([]byte) error); ok { + r1 = rf(b) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockConn_Read_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Read' +type MockConn_Read_Call struct { + *mock.Call +} + +// Read is a helper method to define mock.On call +// - b []byte +func (_e *MockConn_Expecter) Read(b interface{}) *MockConn_Read_Call { + return &MockConn_Read_Call{Call: _e.mock.On("Read", b)} +} + +func (_c *MockConn_Read_Call) Run(run func(b []byte)) *MockConn_Read_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].([]byte)) + }) + return _c +} + +func (_c *MockConn_Read_Call) Return(n int, err error) *MockConn_Read_Call { + _c.Call.Return(n, err) + return _c +} + +func (_c *MockConn_Read_Call) RunAndReturn(run func([]byte) (int, error)) *MockConn_Read_Call { + _c.Call.Return(run) + return _c +} + +// RemoteAddr provides a mock function with given fields: +func (_m *MockConn) RemoteAddr() net.Addr { + ret := _m.Called() + + var r0 net.Addr + if rf, ok := ret.Get(0).(func() net.Addr); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(net.Addr) + } + } + + return r0 +} + +// MockConn_RemoteAddr_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoteAddr' +type MockConn_RemoteAddr_Call struct { + *mock.Call +} + +// RemoteAddr is a helper method to define mock.On call +func (_e *MockConn_Expecter) RemoteAddr() *MockConn_RemoteAddr_Call { + return &MockConn_RemoteAddr_Call{Call: _e.mock.On("RemoteAddr")} +} + +func (_c *MockConn_RemoteAddr_Call) Run(run func()) *MockConn_RemoteAddr_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockConn_RemoteAddr_Call) Return(_a0 net.Addr) *MockConn_RemoteAddr_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockConn_RemoteAddr_Call) RunAndReturn(run func() net.Addr) *MockConn_RemoteAddr_Call { + _c.Call.Return(run) + return _c +} + +// SetDeadline provides a mock function with given fields: t +func (_m *MockConn) SetDeadline(t time.Time) error { + ret := _m.Called(t) + + var r0 error + if rf, ok := ret.Get(0).(func(time.Time) error); ok { + r0 = rf(t) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockConn_SetDeadline_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetDeadline' +type MockConn_SetDeadline_Call struct { + *mock.Call +} + +// SetDeadline is a helper method to define mock.On call +// - t time.Time +func (_e *MockConn_Expecter) SetDeadline(t interface{}) *MockConn_SetDeadline_Call { + return &MockConn_SetDeadline_Call{Call: _e.mock.On("SetDeadline", t)} +} + +func (_c *MockConn_SetDeadline_Call) Run(run func(t time.Time)) *MockConn_SetDeadline_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(time.Time)) + }) + return _c +} + +func (_c *MockConn_SetDeadline_Call) Return(_a0 error) *MockConn_SetDeadline_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockConn_SetDeadline_Call) RunAndReturn(run func(time.Time) error) *MockConn_SetDeadline_Call { + _c.Call.Return(run) + return _c +} + +// SetReadDeadline provides a mock function with given fields: t +func (_m *MockConn) SetReadDeadline(t time.Time) error { + ret := _m.Called(t) + + var r0 error + if rf, ok := ret.Get(0).(func(time.Time) error); ok { + r0 = rf(t) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockConn_SetReadDeadline_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetReadDeadline' +type MockConn_SetReadDeadline_Call struct { + *mock.Call +} + +// SetReadDeadline is a helper method to define mock.On call +// - t time.Time +func (_e *MockConn_Expecter) SetReadDeadline(t interface{}) *MockConn_SetReadDeadline_Call { + return &MockConn_SetReadDeadline_Call{Call: _e.mock.On("SetReadDeadline", t)} +} + +func (_c *MockConn_SetReadDeadline_Call) Run(run func(t time.Time)) *MockConn_SetReadDeadline_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(time.Time)) + }) + return _c +} + +func (_c *MockConn_SetReadDeadline_Call) Return(_a0 error) *MockConn_SetReadDeadline_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockConn_SetReadDeadline_Call) RunAndReturn(run func(time.Time) error) *MockConn_SetReadDeadline_Call { + _c.Call.Return(run) + return _c +} + +// SetWriteDeadline provides a mock function with given fields: t +func (_m *MockConn) SetWriteDeadline(t time.Time) error { + ret := _m.Called(t) + + var r0 error + if rf, ok := ret.Get(0).(func(time.Time) error); ok { + r0 = rf(t) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockConn_SetWriteDeadline_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetWriteDeadline' +type MockConn_SetWriteDeadline_Call struct { + *mock.Call +} + +// SetWriteDeadline is a helper method to define mock.On call +// - t time.Time +func (_e *MockConn_Expecter) SetWriteDeadline(t interface{}) *MockConn_SetWriteDeadline_Call { + return &MockConn_SetWriteDeadline_Call{Call: _e.mock.On("SetWriteDeadline", t)} +} + +func (_c *MockConn_SetWriteDeadline_Call) Run(run func(t time.Time)) *MockConn_SetWriteDeadline_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(time.Time)) + }) + return _c +} + +func (_c *MockConn_SetWriteDeadline_Call) Return(_a0 error) *MockConn_SetWriteDeadline_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockConn_SetWriteDeadline_Call) RunAndReturn(run func(time.Time) error) *MockConn_SetWriteDeadline_Call { + _c.Call.Return(run) + return _c +} + +// Write provides a mock function with given fields: b +func (_m *MockConn) Write(b []byte) (int, error) { + ret := _m.Called(b) + + var r0 int + var r1 error + if rf, ok := ret.Get(0).(func([]byte) (int, error)); ok { + return rf(b) + } + if rf, ok := ret.Get(0).(func([]byte) int); ok { + r0 = rf(b) + } else { + r0 = ret.Get(0).(int) + } + + if rf, ok := ret.Get(1).(func([]byte) error); ok { + r1 = rf(b) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockConn_Write_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Write' +type MockConn_Write_Call struct { + *mock.Call +} + +// Write is a helper method to define mock.On call +// - b []byte +func (_e *MockConn_Expecter) Write(b interface{}) *MockConn_Write_Call { + return &MockConn_Write_Call{Call: _e.mock.On("Write", b)} +} + +func (_c *MockConn_Write_Call) Run(run func(b []byte)) *MockConn_Write_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].([]byte)) + }) + return _c +} + +func (_c *MockConn_Write_Call) Return(n int, err error) *MockConn_Write_Call { + _c.Call.Return(n, err) + return _c +} + +func (_c *MockConn_Write_Call) RunAndReturn(run func([]byte) (int, error)) *MockConn_Write_Call { + _c.Call.Return(run) + return _c +} + +// NewMockConn creates a new instance of MockConn. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockConn(t interface { + mock.TestingT + Cleanup(func()) +}) *MockConn { + mock := &MockConn{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/internal/integration_tests/mocks/mock_Outbound.go b/core/internal/integration_tests/mocks/mock_Outbound.go new file mode 100644 index 0000000..32a747b --- /dev/null +++ b/core/internal/integration_tests/mocks/mock_Outbound.go @@ -0,0 +1,146 @@ +// Code generated by mockery v2.32.0. DO NOT EDIT. + +package mocks + +import ( + net "net" + + mock "github.com/stretchr/testify/mock" + + server "github.com/apernet/hysteria/core/server" +) + +// MockOutbound is an autogenerated mock type for the Outbound type +type MockOutbound struct { + mock.Mock +} + +type MockOutbound_Expecter struct { + mock *mock.Mock +} + +func (_m *MockOutbound) EXPECT() *MockOutbound_Expecter { + return &MockOutbound_Expecter{mock: &_m.Mock} +} + +// TCP provides a mock function with given fields: reqAddr +func (_m *MockOutbound) TCP(reqAddr string) (net.Conn, error) { + ret := _m.Called(reqAddr) + + var r0 net.Conn + var r1 error + if rf, ok := ret.Get(0).(func(string) (net.Conn, error)); ok { + return rf(reqAddr) + } + if rf, ok := ret.Get(0).(func(string) net.Conn); ok { + r0 = rf(reqAddr) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(net.Conn) + } + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(reqAddr) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockOutbound_TCP_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TCP' +type MockOutbound_TCP_Call struct { + *mock.Call +} + +// TCP is a helper method to define mock.On call +// - reqAddr string +func (_e *MockOutbound_Expecter) TCP(reqAddr interface{}) *MockOutbound_TCP_Call { + return &MockOutbound_TCP_Call{Call: _e.mock.On("TCP", reqAddr)} +} + +func (_c *MockOutbound_TCP_Call) Run(run func(reqAddr string)) *MockOutbound_TCP_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockOutbound_TCP_Call) Return(_a0 net.Conn, _a1 error) *MockOutbound_TCP_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockOutbound_TCP_Call) RunAndReturn(run func(string) (net.Conn, error)) *MockOutbound_TCP_Call { + _c.Call.Return(run) + return _c +} + +// UDP provides a mock function with given fields: reqAddr +func (_m *MockOutbound) UDP(reqAddr string) (server.UDPConn, error) { + ret := _m.Called(reqAddr) + + var r0 server.UDPConn + var r1 error + if rf, ok := ret.Get(0).(func(string) (server.UDPConn, error)); ok { + return rf(reqAddr) + } + if rf, ok := ret.Get(0).(func(string) server.UDPConn); ok { + r0 = rf(reqAddr) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(server.UDPConn) + } + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(reqAddr) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockOutbound_UDP_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UDP' +type MockOutbound_UDP_Call struct { + *mock.Call +} + +// UDP is a helper method to define mock.On call +// - reqAddr string +func (_e *MockOutbound_Expecter) UDP(reqAddr interface{}) *MockOutbound_UDP_Call { + return &MockOutbound_UDP_Call{Call: _e.mock.On("UDP", reqAddr)} +} + +func (_c *MockOutbound_UDP_Call) Run(run func(reqAddr string)) *MockOutbound_UDP_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockOutbound_UDP_Call) Return(_a0 server.UDPConn, _a1 error) *MockOutbound_UDP_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockOutbound_UDP_Call) RunAndReturn(run func(string) (server.UDPConn, error)) *MockOutbound_UDP_Call { + _c.Call.Return(run) + return _c +} + +// NewMockOutbound creates a new instance of MockOutbound. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockOutbound(t interface { + mock.TestingT + Cleanup(func()) +}) *MockOutbound { + mock := &MockOutbound{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/internal/integration_tests/mocks/mock_UDPConn.go b/core/internal/integration_tests/mocks/mock_UDPConn.go new file mode 100644 index 0000000..808c18b --- /dev/null +++ b/core/internal/integration_tests/mocks/mock_UDPConn.go @@ -0,0 +1,185 @@ +// Code generated by mockery v2.32.0. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" + +// MockUDPConn is an autogenerated mock type for the UDPConn type +type MockUDPConn struct { + mock.Mock +} + +type MockUDPConn_Expecter struct { + mock *mock.Mock +} + +func (_m *MockUDPConn) EXPECT() *MockUDPConn_Expecter { + return &MockUDPConn_Expecter{mock: &_m.Mock} +} + +// Close provides a mock function with given fields: +func (_m *MockUDPConn) Close() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockUDPConn_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type MockUDPConn_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *MockUDPConn_Expecter) Close() *MockUDPConn_Close_Call { + return &MockUDPConn_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *MockUDPConn_Close_Call) Run(run func()) *MockUDPConn_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockUDPConn_Close_Call) Return(_a0 error) *MockUDPConn_Close_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockUDPConn_Close_Call) RunAndReturn(run func() error) *MockUDPConn_Close_Call { + _c.Call.Return(run) + return _c +} + +// ReadFrom provides a mock function with given fields: b +func (_m *MockUDPConn) ReadFrom(b []byte) (int, string, error) { + ret := _m.Called(b) + + var r0 int + var r1 string + var r2 error + if rf, ok := ret.Get(0).(func([]byte) (int, string, error)); ok { + return rf(b) + } + if rf, ok := ret.Get(0).(func([]byte) int); ok { + r0 = rf(b) + } else { + r0 = ret.Get(0).(int) + } + + if rf, ok := ret.Get(1).(func([]byte) string); ok { + r1 = rf(b) + } else { + r1 = ret.Get(1).(string) + } + + if rf, ok := ret.Get(2).(func([]byte) error); ok { + r2 = rf(b) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockUDPConn_ReadFrom_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReadFrom' +type MockUDPConn_ReadFrom_Call struct { + *mock.Call +} + +// ReadFrom is a helper method to define mock.On call +// - b []byte +func (_e *MockUDPConn_Expecter) ReadFrom(b interface{}) *MockUDPConn_ReadFrom_Call { + return &MockUDPConn_ReadFrom_Call{Call: _e.mock.On("ReadFrom", b)} +} + +func (_c *MockUDPConn_ReadFrom_Call) Run(run func(b []byte)) *MockUDPConn_ReadFrom_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].([]byte)) + }) + return _c +} + +func (_c *MockUDPConn_ReadFrom_Call) Return(_a0 int, _a1 string, _a2 error) *MockUDPConn_ReadFrom_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockUDPConn_ReadFrom_Call) RunAndReturn(run func([]byte) (int, string, error)) *MockUDPConn_ReadFrom_Call { + _c.Call.Return(run) + return _c +} + +// WriteTo provides a mock function with given fields: b, addr +func (_m *MockUDPConn) WriteTo(b []byte, addr string) (int, error) { + ret := _m.Called(b, addr) + + var r0 int + var r1 error + if rf, ok := ret.Get(0).(func([]byte, string) (int, error)); ok { + return rf(b, addr) + } + if rf, ok := ret.Get(0).(func([]byte, string) int); ok { + r0 = rf(b, addr) + } else { + r0 = ret.Get(0).(int) + } + + if rf, ok := ret.Get(1).(func([]byte, string) error); ok { + r1 = rf(b, addr) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockUDPConn_WriteTo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WriteTo' +type MockUDPConn_WriteTo_Call struct { + *mock.Call +} + +// WriteTo is a helper method to define mock.On call +// - b []byte +// - addr string +func (_e *MockUDPConn_Expecter) WriteTo(b interface{}, addr interface{}) *MockUDPConn_WriteTo_Call { + return &MockUDPConn_WriteTo_Call{Call: _e.mock.On("WriteTo", b, addr)} +} + +func (_c *MockUDPConn_WriteTo_Call) Run(run func(b []byte, addr string)) *MockUDPConn_WriteTo_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].([]byte), args[1].(string)) + }) + return _c +} + +func (_c *MockUDPConn_WriteTo_Call) Return(_a0 int, _a1 error) *MockUDPConn_WriteTo_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockUDPConn_WriteTo_Call) RunAndReturn(run func([]byte, string) (int, error)) *MockUDPConn_WriteTo_Call { + _c.Call.Return(run) + return _c +} + +// NewMockUDPConn creates a new instance of MockUDPConn. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockUDPConn(t interface { + mock.TestingT + Cleanup(func()) +}) *MockUDPConn { + mock := &MockUDPConn{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/internal/integration_tests/smoke_test.go b/core/internal/integration_tests/smoke_test.go index eac2b53..30984f9 100644 --- a/core/internal/integration_tests/smoke_test.go +++ b/core/internal/integration_tests/smoke_test.go @@ -28,13 +28,13 @@ func TestClientNoServer(t *testing.T) { var cErr *coreErrs.ConnectError // Try TCP - _, err = c.DialTCP("google.com:443") + _, err = c.TCP("google.com:443") if !errors.As(err, &cErr) { - t.Fatal("expected connect error from DialTCP") + t.Fatal("expected connect error from TCP") } // Try UDP - _, err = c.ListenUDP() + _, err = c.UDP() if !errors.As(err, &cErr) { t.Fatal("expected connect error from DialUDP") } @@ -78,13 +78,13 @@ func TestClientServerBadAuth(t *testing.T) { var aErr *coreErrs.AuthError // Try TCP - _, err = c.DialTCP("google.com:443") + _, err = c.TCP("google.com:443") if !errors.As(err, &aErr) { - t.Fatal("expected auth error from DialTCP") + t.Fatal("expected auth error from TCP") } // Try UDP - _, err = c.ListenUDP() + _, err = c.UDP() if !errors.As(err, &aErr) { t.Fatal("expected auth error from DialUDP") } @@ -134,7 +134,7 @@ func TestClientServerTCPEcho(t *testing.T) { defer c.Close() // Dial TCP - conn, err := c.DialTCP(echoTCPAddr.String()) + conn, err := c.TCP(echoTCPAddr.String()) if err != nil { t.Fatal("error dialing TCP:", err) } @@ -200,7 +200,7 @@ func TestClientServerUDPEcho(t *testing.T) { defer c.Close() // Listen UDP - conn, err := c.ListenUDP() + conn, err := c.UDP() if err != nil { t.Fatal("error listening UDP:", err) } diff --git a/core/internal/integration_tests/stress_test.go b/core/internal/integration_tests/stress_test.go index eebb5ee..247d0c2 100644 --- a/core/internal/integration_tests/stress_test.go +++ b/core/internal/integration_tests/stress_test.go @@ -172,7 +172,7 @@ func TestClientServerTCPStress(t *testing.T) { defer c.Close() dialFunc := func() (net.Conn, error) { - return c.DialTCP(echoTCPAddr.String()) + return c.TCP(echoTCPAddr.String()) } t.Run("Single 500m", (&tcpStressor{DialFunc: dialFunc, Size: 524288000, Parallel: 1, Iterations: 1}).Run) @@ -227,7 +227,7 @@ func TestClientServerUDPStress(t *testing.T) { defer c.Close() t.Run("Single 1000x100b", (&udpStressor{ - ListenFunc: c.ListenUDP, + ListenFunc: c.UDP, ServerAddr: echoUDPAddr.String(), Size: 100, Count: 1000, @@ -235,7 +235,7 @@ func TestClientServerUDPStress(t *testing.T) { Iterations: 1, }).Run) t.Run("Single 1000x3k", (&udpStressor{ - ListenFunc: c.ListenUDP, + ListenFunc: c.UDP, ServerAddr: echoUDPAddr.String(), Size: 3000, Count: 1000, @@ -244,7 +244,7 @@ func TestClientServerUDPStress(t *testing.T) { }).Run) t.Run("5 Sequential 1000x100b", (&udpStressor{ - ListenFunc: c.ListenUDP, + ListenFunc: c.UDP, ServerAddr: echoUDPAddr.String(), Size: 100, Count: 1000, @@ -252,7 +252,7 @@ func TestClientServerUDPStress(t *testing.T) { Iterations: 5, }).Run) t.Run("5 Sequential 200x3k", (&udpStressor{ - ListenFunc: c.ListenUDP, + ListenFunc: c.UDP, ServerAddr: echoUDPAddr.String(), Size: 3000, Count: 200, @@ -261,7 +261,7 @@ func TestClientServerUDPStress(t *testing.T) { }).Run) t.Run("2 Sequential 5 Parallel 1000x100b", (&udpStressor{ - ListenFunc: c.ListenUDP, + ListenFunc: c.UDP, ServerAddr: echoUDPAddr.String(), Size: 100, Count: 1000, @@ -270,7 +270,7 @@ func TestClientServerUDPStress(t *testing.T) { }).Run) t.Run("2 Sequential 5 Parallel 200x3k", (&udpStressor{ - ListenFunc: c.ListenUDP, + ListenFunc: c.UDP, ServerAddr: echoUDPAddr.String(), Size: 3000, Count: 200, @@ -279,7 +279,7 @@ func TestClientServerUDPStress(t *testing.T) { }).Run) t.Run("10 Sequential 5 Parallel 200x3k", (&udpStressor{ - ListenFunc: c.ListenUDP, + ListenFunc: c.UDP, ServerAddr: echoUDPAddr.String(), Size: 3000, Count: 200, diff --git a/core/internal/integration_tests/trafficlogger_test.go b/core/internal/integration_tests/trafficlogger_test.go index 3d79f8f..6787042 100644 --- a/core/internal/integration_tests/trafficlogger_test.go +++ b/core/internal/integration_tests/trafficlogger_test.go @@ -84,7 +84,7 @@ func TestServerTrafficLogger(t *testing.T) { defer c.Close() // Dial TCP - tConn, err := c.DialTCP(echoTCPAddr.String()) + tConn, err := c.TCP(echoTCPAddr.String()) if err != nil { t.Fatal("error dialing TCP:", err) } @@ -124,7 +124,7 @@ func TestServerTrafficLogger(t *testing.T) { go uEchoServer.Serve() // Listen UDP - uConn, err := c.ListenUDP() + uConn, err := c.UDP() if err != nil { t.Fatal("error listening UDP:", err) } diff --git a/core/internal/integration_tests/utils_test.go b/core/internal/integration_tests/utils_test.go index f87dc24..60bfc52 100644 --- a/core/internal/integration_tests/utils_test.go +++ b/core/internal/integration_tests/utils_test.go @@ -26,6 +26,15 @@ func serverTLSConfig() server.TLSConfig { } } +func serverConn() (net.PacketConn, net.Addr, error) { + udpAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 14514} + udpConn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + return nil, nil, err + } + return udpConn, udpAddr, nil +} + type pwAuthenticator struct { Password string ID string diff --git a/core/server/config.go b/core/server/config.go index 3eb183d..368da24 100644 --- a/core/server/config.go +++ b/core/server/config.go @@ -15,6 +15,7 @@ const ( defaultConnReceiveWindow = defaultStreamReceiveWindow * 5 / 2 // 20MB defaultMaxIdleTimeout = 30 * time.Second defaultMaxIncomingStreams = 1024 + defaultUDPIdleTimeout = 60 * time.Second ) type Config struct { @@ -24,6 +25,7 @@ type Config struct { Outbound Outbound BandwidthConfig BandwidthConfig DisableUDP bool + UDPIdleTimeout time.Duration Authenticator Authenticator EventLogger EventLogger TrafficLogger TrafficLogger @@ -79,6 +81,11 @@ func (c *Config) fill() error { if c.BandwidthConfig.MaxRx != 0 && c.BandwidthConfig.MaxRx < 65536 { return errors.ConfigError{Field: "BandwidthConfig.MaxRx", Reason: "must be at least 65536"} } + if c.UDPIdleTimeout == 0 { + c.UDPIdleTimeout = defaultUDPIdleTimeout + } else if c.UDPIdleTimeout < 2*time.Second || c.UDPIdleTimeout > 600*time.Second { + return errors.ConfigError{Field: "UDPIdleTimeout", Reason: "must be between 2s and 600s"} + } if c.Authenticator == nil { return errors.ConfigError{Field: "Authenticator", Reason: "must be set"} } diff --git a/core/server/server.go b/core/server/server.go index f780540..93d3d4a 100644 --- a/core/server/server.go +++ b/core/server/server.go @@ -5,21 +5,18 @@ import ( "crypto/tls" "net/http" "sync" - "time" + + "github.com/quic-go/quic-go" + "github.com/quic-go/quic-go/http3" "github.com/apernet/hysteria/core/internal/congestion" "github.com/apernet/hysteria/core/internal/protocol" "github.com/apernet/hysteria/core/internal/utils" - - "github.com/quic-go/quic-go" - "github.com/quic-go/quic-go/http3" ) const ( closeErrCodeOK = 0x100 // HTTP3 ErrCodeNoError closeErrCodeTrafficLimitReached = 0x107 // HTTP3 ErrCodeExcessiveLoad - - udpSessionIdleTimeout = 60 * time.Second ) type Server interface { @@ -148,7 +145,7 @@ func (h *h3sHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { sm := newUDPSessionManager( &udpIOImpl{h.conn, id, h.config.TrafficLogger, h.config.Outbound}, &udpEventLoggerImpl{h.conn, id, h.config.EventLogger}, - udpSessionIdleTimeout) + h.config.UDPIdleTimeout) h.udpSM = sm go sm.Run() }) diff --git a/core/server/udp_test.go b/core/server/udp_test.go index cb1cb60..ccfb8cf 100644 --- a/core/server/udp_test.go +++ b/core/server/udp_test.go @@ -170,6 +170,6 @@ func TestUDPSessionManager(t *testing.T) { // Leak checks close(msgCh) // This will return error from ReceiveMessage(), should stop the session manager time.Sleep(1 * time.Second) // Wait one more second just to be sure - assert.Equal(t, sm.Count(), 0, "session count should be 0") + assert.Zero(t, sm.Count(), "session count should be 0") goleak.VerifyNone(t) } diff --git a/extras/outbounds/interface_test.go b/extras/outbounds/interface_test.go index aa2aa12..3734b66 100644 --- a/extras/outbounds/interface_test.go +++ b/extras/outbounds/interface_test.go @@ -58,15 +58,15 @@ func TestPluggableOutboundAdapter(t *testing.T) { adapter := &PluggableOutboundAdapter{ PluggableOutbound: &mockPluggableOutbound{}, } - // DialTCP with correct addr + // TCP with correct addr _, err := adapter.DialTCP("correct_host_1:34567") if err != nil { - t.Fatal("DialTCP with correct addr failed", err) + t.Fatal("TCP with correct addr failed", err) } - // DialTCP with wrong addr + // TCP with wrong addr _, err = adapter.DialTCP("wrong_host_1:34567") if err != errWrongAddr { - t.Fatal("DialTCP with wrong addr should fail, got", err) + t.Fatal("TCP with wrong addr should fail, got", err) } // DialUDP uConn, err := adapter.DialUDP()