diff --git a/mock_send_conn_test.go b/mock_send_conn_test.go index bfbaee73..a025eb8b 100644 --- a/mock_send_conn_test.go +++ b/mock_send_conn_test.go @@ -41,6 +41,42 @@ func (m *MockSendConn) EXPECT() *MockSendConnMockRecorder { return m.recorder } +// ChangeRemoteAddr mocks base method. +func (m *MockSendConn) ChangeRemoteAddr(addr net.Addr, info packetInfo) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ChangeRemoteAddr", addr, info) +} + +// ChangeRemoteAddr indicates an expected call of ChangeRemoteAddr. +func (mr *MockSendConnMockRecorder) ChangeRemoteAddr(addr, info any) *MockSendConnChangeRemoteAddrCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ChangeRemoteAddr", reflect.TypeOf((*MockSendConn)(nil).ChangeRemoteAddr), addr, info) + return &MockSendConnChangeRemoteAddrCall{Call: call} +} + +// MockSendConnChangeRemoteAddrCall wrap *gomock.Call +type MockSendConnChangeRemoteAddrCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSendConnChangeRemoteAddrCall) Return() *MockSendConnChangeRemoteAddrCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSendConnChangeRemoteAddrCall) Do(f func(net.Addr, packetInfo)) *MockSendConnChangeRemoteAddrCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSendConnChangeRemoteAddrCall) DoAndReturn(f func(net.Addr, packetInfo)) *MockSendConnChangeRemoteAddrCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // Close mocks base method. func (m *MockSendConn) Close() error { m.ctrl.T.Helper() @@ -193,6 +229,44 @@ func (c *MockSendConnWriteCall) DoAndReturn(f func([]byte, uint16, protocol.ECN) return c } +// WriteTo mocks base method. +func (m *MockSendConn) WriteTo(arg0 []byte, arg1 net.Addr) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WriteTo", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// WriteTo indicates an expected call of WriteTo. +func (mr *MockSendConnMockRecorder) WriteTo(arg0, arg1 any) *MockSendConnWriteToCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteTo", reflect.TypeOf((*MockSendConn)(nil).WriteTo), arg0, arg1) + return &MockSendConnWriteToCall{Call: call} +} + +// MockSendConnWriteToCall wrap *gomock.Call +type MockSendConnWriteToCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSendConnWriteToCall) Return(arg0 error) *MockSendConnWriteToCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSendConnWriteToCall) Do(f func([]byte, net.Addr) error) *MockSendConnWriteToCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSendConnWriteToCall) DoAndReturn(f func([]byte, net.Addr) error) *MockSendConnWriteToCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // capabilities mocks base method. func (m *MockSendConn) capabilities() connCapabilities { m.ctrl.T.Helper() diff --git a/send_conn.go b/send_conn.go index 498ed112..402520c6 100644 --- a/send_conn.go +++ b/send_conn.go @@ -2,6 +2,7 @@ package quic import ( "net" + "sync/atomic" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" @@ -10,22 +11,29 @@ import ( // A sendConn allows sending using a simple Write() on a non-connected packet conn. type sendConn interface { Write(b []byte, gsoSize uint16, ecn protocol.ECN) error + WriteTo([]byte, net.Addr) error Close() error LocalAddr() net.Addr RemoteAddr() net.Addr + ChangeRemoteAddr(addr net.Addr, info packetInfo) capabilities() connCapabilities } +type remoteAddrInfo struct { + addr net.Addr + oob []byte +} + type sconn struct { rawConn - localAddr net.Addr - remoteAddr net.Addr + localAddr net.Addr + + remoteAddrInfo atomic.Pointer[remoteAddrInfo] logger utils.Logger - packetInfoOOB []byte // If GSO enabled, and we receive a GSO error for this remote address, GSO is disabled. gotGSOError bool // Used to catch the error sometimes returned by the first sendmsg call on Linux, @@ -49,22 +57,26 @@ func newSendConn(c rawConn, remote net.Addr, info packetInfo, logger utils.Logge // increase oob slice capacity, so we can add the UDP_SEGMENT and ECN control messages without allocating l := len(oob) oob = append(oob, make([]byte, 64)...)[:l] - return &sconn{ - rawConn: c, - localAddr: localAddr, - remoteAddr: remote, - packetInfoOOB: oob, - logger: logger, + sc := &sconn{ + rawConn: c, + localAddr: localAddr, + logger: logger, } + sc.remoteAddrInfo.Store(&remoteAddrInfo{ + addr: remote, + oob: oob, + }) + return sc } func (c *sconn) Write(p []byte, gsoSize uint16, ecn protocol.ECN) error { - err := c.writePacket(p, c.remoteAddr, c.packetInfoOOB, gsoSize, ecn) + ai := c.remoteAddrInfo.Load() + err := c.writePacket(p, ai.addr, ai.oob, gsoSize, ecn) if err != nil && isGSOError(err) { // disable GSO for future calls c.gotGSOError = true if c.logger.Debug() { - c.logger.Debugf("GSO failed when sending to %s", c.remoteAddr) + c.logger.Debugf("GSO failed when sending to %s", ai.addr) } // send out the packets one by one for len(p) > 0 { @@ -72,7 +84,7 @@ func (c *sconn) Write(p []byte, gsoSize uint16, ecn protocol.ECN) error { if l > int(gsoSize) { l = int(gsoSize) } - if err := c.writePacket(p[:l], c.remoteAddr, c.packetInfoOOB, 0, ecn); err != nil { + if err := c.writePacket(p[:l], ai.addr, ai.oob, 0, ecn); err != nil { return err } p = p[l:] @@ -91,6 +103,11 @@ func (c *sconn) writePacket(p []byte, addr net.Addr, oob []byte, gsoSize uint16, return err } +func (c *sconn) WriteTo(b []byte, addr net.Addr) error { + _, err := c.WritePacket(b, addr, nil, 0, protocol.ECNUnsupported) + return err +} + func (c *sconn) capabilities() connCapabilities { capabilities := c.rawConn.capabilities() if capabilities.GSO { @@ -99,5 +116,12 @@ func (c *sconn) capabilities() connCapabilities { return capabilities } -func (c *sconn) RemoteAddr() net.Addr { return c.remoteAddr } +func (c *sconn) ChangeRemoteAddr(addr net.Addr, info packetInfo) { + c.remoteAddrInfo.Store(&remoteAddrInfo{ + addr: addr, + oob: info.OOB(), + }) +} + +func (c *sconn) RemoteAddr() net.Addr { return c.remoteAddrInfo.Load().addr } func (c *sconn) LocalAddr() net.Addr { return c.localAddr } diff --git a/send_conn_test.go b/send_conn_test.go index 017c3f7a..0df2a81b 100644 --- a/send_conn_test.go +++ b/send_conn_test.go @@ -5,6 +5,7 @@ import ( "net/netip" "runtime" "testing" + "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" @@ -98,3 +99,37 @@ func TestSendConnSendmsgFailures(t *testing.T) { require.Error(t, c.Write([]byte("foobar"), 0, protocol.ECNCE)) }) } + +func TestSendConnRemoteAddrChange(t *testing.T) { + ln1 := newUPDConnLocalhost(t) + ln2 := newUPDConnLocalhost(t) + + c := newSendConn( + &basicConn{PacketConn: newUPDConnLocalhost(t)}, + ln1.LocalAddr(), + packetInfo{}, + utils.DefaultLogger, + ) + + require.NoError(t, c.Write([]byte("foobar"), 0, protocol.ECNUnsupported)) + ln1.SetReadDeadline(time.Now().Add(time.Second)) + b := make([]byte, 1024) + n, err := ln1.Read(b) + require.NoError(t, err) + require.Equal(t, "foobar", string(b[:n])) + + require.NoError(t, c.WriteTo([]byte("foobaz"), ln2.LocalAddr())) + ln2.SetReadDeadline(time.Now().Add(time.Second)) + b = make([]byte, 1024) + n, err = ln2.Read(b) + require.NoError(t, err) + require.Equal(t, "foobaz", string(b[:n])) + + c.ChangeRemoteAddr(ln2.LocalAddr(), packetInfo{}) + require.NoError(t, c.Write([]byte("lorem ipsum"), 0, protocol.ECNUnsupported)) + ln2.SetReadDeadline(time.Now().Add(time.Second)) + b = make([]byte, 1024) + n, err = ln2.Read(b) + require.NoError(t, err) + require.Equal(t, "lorem ipsum", string(b[:n])) +}