diff --git a/mock_raw_conn_test.go b/mock_raw_conn_test.go new file mode 100644 index 00000000..66b9c611 --- /dev/null +++ b/mock_raw_conn_test.go @@ -0,0 +1,122 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/quic-go/quic-go (interfaces: RawConn) + +// Package quic is a generated GoMock package. +package quic + +import ( + net "net" + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" +) + +// MockRawConn is a mock of RawConn interface. +type MockRawConn struct { + ctrl *gomock.Controller + recorder *MockRawConnMockRecorder +} + +// MockRawConnMockRecorder is the mock recorder for MockRawConn. +type MockRawConnMockRecorder struct { + mock *MockRawConn +} + +// NewMockRawConn creates a new mock instance. +func NewMockRawConn(ctrl *gomock.Controller) *MockRawConn { + mock := &MockRawConn{ctrl: ctrl} + mock.recorder = &MockRawConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRawConn) EXPECT() *MockRawConnMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockRawConn) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockRawConnMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRawConn)(nil).Close)) +} + +// LocalAddr mocks base method. +func (m *MockRawConn) LocalAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LocalAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// LocalAddr indicates an expected call of LocalAddr. +func (mr *MockRawConnMockRecorder) LocalAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockRawConn)(nil).LocalAddr)) +} + +// ReadPacket mocks base method. +func (m *MockRawConn) ReadPacket() (receivedPacket, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadPacket") + ret0, _ := ret[0].(receivedPacket) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadPacket indicates an expected call of ReadPacket. +func (mr *MockRawConnMockRecorder) ReadPacket() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadPacket", reflect.TypeOf((*MockRawConn)(nil).ReadPacket)) +} + +// SetReadDeadline mocks base method. +func (m *MockRawConn) SetReadDeadline(arg0 time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetReadDeadline", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetReadDeadline indicates an expected call of SetReadDeadline. +func (mr *MockRawConnMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockRawConn)(nil).SetReadDeadline), arg0) +} + +// WritePacket mocks base method. +func (m *MockRawConn) WritePacket(arg0 []byte, arg1 net.Addr, arg2 []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WritePacket", arg0, arg1, arg2) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// WritePacket indicates an expected call of WritePacket. +func (mr *MockRawConnMockRecorder) WritePacket(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WritePacket", reflect.TypeOf((*MockRawConn)(nil).WritePacket), arg0, arg1, arg2) +} + +// capabilities mocks base method. +func (m *MockRawConn) capabilities() connCapabilities { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "capabilities") + ret0, _ := ret[0].(connCapabilities) + return ret0 +} + +// capabilities indicates an expected call of capabilities. +func (mr *MockRawConnMockRecorder) capabilities() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "capabilities", reflect.TypeOf((*MockRawConn)(nil).capabilities)) +} diff --git a/mockgen.go b/mockgen.go index eb700864..221c1367 100644 --- a/mockgen.go +++ b/mockgen.go @@ -5,6 +5,9 @@ package quic //go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_send_conn_test.go github.com/quic-go/quic-go SendConn" type SendConn = sendConn +//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_raw_conn_test.go github.com/quic-go/quic-go RawConn" +type RawConn = rawConn + //go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_sender_test.go github.com/quic-go/quic-go Sender" type Sender = sender diff --git a/packet_handler_map.go b/packet_handler_map.go index 2a16773c..e0f0567d 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -26,9 +26,9 @@ type connCapabilities struct { // rawConn is a connection that allow reading of a receivedPackeh. type rawConn interface { ReadPacket() (receivedPacket, error) - // The size parameter is used for GSO. - // If GSO is not support, len(b) must be equal to size. - WritePacket(b []byte, size uint16, addr net.Addr, oob []byte) (int, error) + // WritePacket writes a packet on the wire. + // If GSO is enabled, it's the caller's responsibility to set the correct control message. + WritePacket(b []byte, addr net.Addr, oob []byte) (int, error) LocalAddr() net.Addr SetReadDeadline(time.Time) error io.Closer diff --git a/send_conn.go b/send_conn.go index 4e7007fa..34cbfd6e 100644 --- a/send_conn.go +++ b/send_conn.go @@ -1,10 +1,12 @@ package quic import ( + "fmt" "math" "net" "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" ) // A sendConn allows sending using a simple Write() on a non-connected packet conn. @@ -20,61 +22,84 @@ type sendConn interface { type sconn struct { rawConn + localAddr net.Addr remoteAddr net.Addr - info packetInfo - oob []byte + + logger utils.Logger + + info packetInfo + oob []byte + // If GSO enabled, and we receive a GSO error for this remote address, GSO is disabled. + gotGSOError bool } var _ sendConn = &sconn{} -func newSendConn(c rawConn, remote net.Addr) *sconn { - sc := &sconn{ - rawConn: c, - remoteAddr: remote, +func newSendConn(c rawConn, remote net.Addr, info packetInfo, logger utils.Logger) *sconn { + localAddr := c.LocalAddr() + if info.addr.IsValid() { + if udpAddr, ok := localAddr.(*net.UDPAddr); ok { + addrCopy := *udpAddr + addrCopy.IP = info.addr.AsSlice() + localAddr = &addrCopy + } } - if c.capabilities().GSO { - // add 32 bytes, so we can add the UDP_SEGMENT msg - sc.oob = make([]byte, 0, 32) - } - return sc -} -func newSendConnWithPacketInfo(c rawConn, remote net.Addr, info packetInfo) *sconn { oob := info.OOB() - if c.capabilities().GSO { - // add 32 bytes, so we can add the UDP_SEGMENT msg - l := len(oob) - oob = append(oob, make([]byte, 32)...) - oob = oob[:l] - } + // add 32 bytes, so we can add the UDP_SEGMENT msg + l := len(oob) + oob = append(oob, make([]byte, 32)...) + oob = oob[:l] return &sconn{ rawConn: c, + localAddr: localAddr, remoteAddr: remote, info: info, oob: oob, + logger: logger, } } func (c *sconn) Write(p []byte, size protocol.ByteCount) error { + if !c.capabilities().GSO { + if protocol.ByteCount(len(p)) != size { + panic(fmt.Sprintf("inconsistent packet size (%d vs %d)", len(p), size)) + } + _, err := c.WritePacket(p, c.remoteAddr, c.oob) + return err + } + // GSO is supported. Append the control message and send. if size > math.MaxUint16 { panic("size overflow") } - _, err := c.WritePacket(p, uint16(size), c.remoteAddr, c.oob) + _, err := c.WritePacket(p, c.remoteAddr, appendUDPSegmentSizeMsg(c.oob, uint16(size))) + if err != nil && isGSOError(err) { + // disable GSO for future calls + c.gotGSOError = true + if c.logger.Debug() { + c.logger.Debugf("GSO failed when sending to %s", c.remoteAddr) + } + // send out the packets one by one + for len(p) > 0 { + l := len(p) + if l > int(size) { + l = int(size) + } + if _, err := c.WritePacket(p[:l], c.remoteAddr, c.oob); err != nil { + return err + } + p = p[l:] + } + return nil + } return err } -func (c *sconn) RemoteAddr() net.Addr { - return c.remoteAddr +func (c *sconn) capabilities() connCapabilities { + capabilities := c.rawConn.capabilities() + capabilities.GSO = !c.gotGSOError + return capabilities } -func (c *sconn) LocalAddr() net.Addr { - addr := c.rawConn.LocalAddr() - if c.info.addr.IsValid() { - if udpAddr, ok := addr.(*net.UDPAddr); ok { - addrCopy := *udpAddr - addrCopy.IP = c.info.addr.AsSlice() - addr = &addrCopy - } - } - return addr -} +func (c *sconn) RemoteAddr() net.Addr { return c.remoteAddr } +func (c *sconn) LocalAddr() net.Addr { return c.localAddr } diff --git a/send_conn_test.go b/send_conn_test.go index 56fe9236..8676d409 100644 --- a/send_conn_test.go +++ b/send_conn_test.go @@ -2,46 +2,81 @@ package quic import ( "net" + "net/netip" + "github.com/quic-go/quic-go/internal/utils" + + "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) +// Only if appendUDPSegmentSizeMsg actually appends a message (and isn't only a stub implementation), +// GSO is actually supported on this platform. +var platformSupportsGSO = len(appendUDPSegmentSizeMsg([]byte{}, 1337)) > 0 + var _ = Describe("Connection (for sending packets)", func() { - var ( - c sendConn - packetConn *MockPacketConn - addr net.Addr - ) + remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} - BeforeEach(func() { - addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} - packetConn = NewMockPacketConn(mockCtrl) - rawConn, err := wrapConn(packetConn) - Expect(err).ToNot(HaveOccurred()) - c = newSendConnWithPacketInfo(rawConn, addr, packetInfo{}) - }) - - It("writes", func() { - packetConn.EXPECT().WriteTo([]byte("foobar"), addr) - Expect(c.Write([]byte("foobar"), 6)).To(Succeed()) - }) - - It("gets the remote address", func() { + It("gets the local and remote addresses", func() { + localAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1234} + rawConn := NewMockRawConn(mockCtrl) + rawConn.EXPECT().LocalAddr().Return(localAddr) + c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger) + Expect(c.LocalAddr().String()).To(Equal("192.168.0.1:1234")) Expect(c.RemoteAddr().String()).To(Equal("192.168.100.200:1337")) }) - It("gets the local address", func() { - addr := &net.UDPAddr{ - IP: net.IPv4(192, 168, 0, 1), - Port: 1234, - } - packetConn.EXPECT().LocalAddr().Return(addr) - Expect(c.LocalAddr()).To(Equal(addr)) + It("uses the local address from the packet info", func() { + localAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1234} + rawConn := NewMockRawConn(mockCtrl) + rawConn.EXPECT().LocalAddr().Return(localAddr) + c := newSendConn(rawConn, remoteAddr, packetInfo{addr: netip.AddrFrom4([4]byte{127, 0, 0, 42})}, utils.DefaultLogger) + Expect(c.LocalAddr().String()).To(Equal("127.0.0.42:1234")) }) - It("closes", func() { - packetConn.EXPECT().Close() - Expect(c.Close()).To(Succeed()) - }) + if platformSupportsGSO { + It("writes with GSO", func() { + rawConn := NewMockRawConn(mockCtrl) + rawConn.EXPECT().LocalAddr() + rawConn.EXPECT().capabilities().Return(connCapabilities{GSO: true}).AnyTimes() + c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger) + rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any()).Do(func(_ []byte, _ net.Addr, oob []byte) { + msg := appendUDPSegmentSizeMsg([]byte{}, 3) + Expect(oob).To(Equal(msg)) + }) + Expect(c.Write([]byte("foobar"), 3)).To(Succeed()) + }) + + It("disables GSO if writing fails", func() { + rawConn := NewMockRawConn(mockCtrl) + rawConn.EXPECT().LocalAddr() + rawConn.EXPECT().capabilities().Return(connCapabilities{GSO: true}).AnyTimes() + c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger) + Expect(c.capabilities().GSO).To(BeTrue()) + gomock.InOrder( + rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any()).DoAndReturn(func(_ []byte, _ net.Addr, oob []byte) (int, error) { + msg := appendUDPSegmentSizeMsg([]byte{}, 3) + Expect(oob).To(Equal(msg)) + return 0, errGSO + }), + rawConn.EXPECT().WritePacket([]byte("foo"), remoteAddr, []byte{}).Return(3, nil), + rawConn.EXPECT().WritePacket([]byte("bar"), remoteAddr, []byte{}).Return(3, nil), + ) + Expect(c.Write([]byte("foobar"), 3)).To(Succeed()) + Expect(c.capabilities().GSO).To(BeFalse()) // GSO support is now disabled + // make sure we actually enforce that + Expect(func() { c.Write([]byte("foobar"), 3) }).To(PanicWith("inconsistent packet size (6 vs 3)")) + }) + } else { + It("writes without GSO", func() { + remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} + rawConn := NewMockRawConn(mockCtrl) + rawConn.EXPECT().LocalAddr() + rawConn.EXPECT().capabilities() + c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger) + rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, nil) + Expect(c.Write([]byte("foobar"), 6)).To(Succeed()) + }) + } }) diff --git a/server.go b/server.go index 0f8219e3..c06228c9 100644 --- a/server.go +++ b/server.go @@ -632,7 +632,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID) } conn = s.newConn( - newSendConnWithPacketInfo(s.conn, p.remoteAddr, p.info), + newSendConn(s.conn, p.remoteAddr, p.info, s.logger), s.connHandler, origDestConnID, retrySrcConnID, @@ -742,7 +742,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info packe if s.tracer != nil { s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil) } - _, err = s.conn.WritePacket(buf.Data, uint16(len(buf.Data)), remoteAddr, info.OOB()) + _, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB()) return err } @@ -841,7 +841,7 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han if s.tracer != nil { s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf}) } - _, err = s.conn.WritePacket(b.Data, uint16(len(b.Data)), remoteAddr, info.OOB()) + _, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB()) return err } @@ -879,7 +879,7 @@ func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) { if s.tracer != nil { s.tracer.SentVersionNegotiationPacket(p.remoteAddr, src, dest, s.config.Versions) } - if _, err := s.conn.WritePacket(data, uint16(len(data)), p.remoteAddr, p.info.OOB()); err != nil { + if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { s.logger.Debugf("Error sending Version Negotiation: %s", err) } } diff --git a/sys_conn.go b/sys_conn.go index 414472e7..f2224e4c 100644 --- a/sys_conn.go +++ b/sys_conn.go @@ -1,7 +1,6 @@ package quic import ( - "fmt" "log" "net" "os" @@ -105,10 +104,7 @@ func (c *basicConn) ReadPacket() (receivedPacket, error) { }, nil } -func (c *basicConn) WritePacket(b []byte, packetSize uint16, addr net.Addr, _ []byte) (n int, err error) { - if uint16(len(b)) != packetSize { - panic(fmt.Sprintf("inconsistent length. got: %d. expected %d", packetSize, len(b))) - } +func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte) (n int, err error) { return c.PacketConn.WriteTo(b, addr) } diff --git a/sys_conn_df_linux.go b/sys_conn_df_linux.go index 199f6347..f09eaa5d 100644 --- a/sys_conn_df_linux.go +++ b/sys_conn_df_linux.go @@ -4,11 +4,7 @@ package quic import ( "errors" - "log" - "os" - "strconv" "syscall" - "unsafe" "golang.org/x/sys/unix" @@ -38,43 +34,9 @@ func setDF(rawConn syscall.RawConn) (bool, error) { return true, nil } -func maybeSetGSO(rawConn syscall.RawConn) bool { - enable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_ENABLE_GSO")) - if !enable { - return false - } - - var setErr error - if err := rawConn.Control(func(fd uintptr) { - setErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_UDP, unix.UDP_SEGMENT, 1) - }); err != nil { - setErr = err - } - if setErr != nil { - log.Println("failed to enable GSO") - return false - } - return true -} - func isSendMsgSizeErr(err error) bool { // https://man7.org/linux/man-pages/man7/udp.7.html return errors.Is(err, unix.EMSGSIZE) } -func isRecvMsgSizeErr(err error) bool { return false } - -func appendUDPSegmentSizeMsg(b []byte, size uint16) []byte { - startLen := len(b) - const dataLen = 2 // payload is a uint16 - b = append(b, make([]byte, unix.CmsgSpace(dataLen))...) - h := (*unix.Cmsghdr)(unsafe.Pointer(&b[startLen])) - h.Level = syscall.IPPROTO_UDP - h.Type = unix.UDP_SEGMENT - h.SetLen(unix.CmsgLen(dataLen)) - - // UnixRights uses the private `data` method, but I *think* this achieves the same goal. - offset := startLen + unix.CmsgSpace(0) - *(*uint16)(unsafe.Pointer(&b[offset])) = size - return b -} +func isRecvMsgSizeErr(error) bool { return false } diff --git a/sys_conn_helper_linux.go b/sys_conn_helper_linux.go index 61224eaa..4e87bba0 100644 --- a/sys_conn_helper_linux.go +++ b/sys_conn_helper_linux.go @@ -4,8 +4,11 @@ package quic import ( "encoding/binary" + "errors" "net/netip" + "os" "syscall" + "unsafe" "golang.org/x/sys/unix" ) @@ -48,3 +51,30 @@ func parseIPv4PktInfo(body []byte) (ip netip.Addr, ifIndex uint32, ok bool) { } return netip.AddrFrom4(*(*[4]byte)(body[8:12])), binary.LittleEndian.Uint32(body), true } + +func appendUDPSegmentSizeMsg(b []byte, size uint16) []byte { + startLen := len(b) + const dataLen = 2 // payload is a uint16 + b = append(b, make([]byte, unix.CmsgSpace(dataLen))...) + h := (*unix.Cmsghdr)(unsafe.Pointer(&b[startLen])) + h.Level = syscall.IPPROTO_UDP + h.Type = unix.UDP_SEGMENT + h.SetLen(unix.CmsgLen(dataLen)) + + // UnixRights uses the private `data` method, but I *think* this achieves the same goal. + offset := startLen + unix.CmsgSpace(0) + *(*uint16)(unsafe.Pointer(&b[offset])) = size + return b +} + +func isGSOError(err error) bool { + var serr *os.SyscallError + if errors.As(err, &serr) { + // EIO is returned by udp_send_skb() if the device driver does not have tx checksums enabled, + // which is a hard requirement of UDP_SEGMENT. See: + // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228 + // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942 + return serr.Err == unix.EIO + } + return false +} diff --git a/sys_conn_helper_linux_test.go b/sys_conn_helper_linux_test.go index fa39c523..4cf59abe 100644 --- a/sys_conn_helper_linux_test.go +++ b/sys_conn_helper_linux_test.go @@ -1,22 +1,24 @@ -// We need root permissions to use RCVBUFFORCE. -// This test is therefore only compiled when the root build flag is set. -// It can only succeed if the tests are then also run with root permissions. -//go:build linux && root +//go:build linux package quic import ( + "errors" "net" "os" + "golang.org/x/sys/unix" + . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) +var errGSO = &os.SyscallError{Err: unix.EIO} + var _ = Describe("forcing a change of send and receive buffer sizes", func() { It("forces a change of the receive buffer size", func() { if os.Getuid() != 0 { - Fail("Must be root to force change the receive buffer size") + Skip("Must be root to force change the receive buffer size") } c, err := net.ListenPacket("udp", "127.0.0.1:0") @@ -43,7 +45,7 @@ var _ = Describe("forcing a change of send and receive buffer sizes", func() { It("forces a change of the send buffer size", func() { if os.Getuid() != 0 { - Fail("Must be root to force change the send buffer size") + Skip("Must be root to force change the send buffer size") } c, err := net.ListenPacket("udp", "127.0.0.1:0") @@ -67,4 +69,10 @@ var _ = Describe("forcing a change of send and receive buffer sizes", func() { // The kernel doubles this value (to allow space for bookkeeping overhead) Expect(size).To(Equal(2 * large)) }) + + It("detects GSO errors", func() { + Expect(isGSOError(errGSO)).To(BeTrue()) + Expect(isGSOError(nil)).To(BeFalse()) + Expect(isGSOError(errors.New("test"))).To(BeFalse()) + }) }) diff --git a/sys_conn_helper_nonlinux.go b/sys_conn_helper_nonlinux.go index 80b795c3..48ab10aa 100644 --- a/sys_conn_helper_nonlinux.go +++ b/sys_conn_helper_nonlinux.go @@ -4,3 +4,6 @@ package quic func forceSetReceiveBuffer(c any, bytes int) error { return nil } func forceSetSendBuffer(c any, bytes int) error { return nil } + +func appendUDPSegmentSizeMsg(_ []byte, _ uint16) []byte { return nil } +func isGSOError(error) bool { return false } diff --git a/sys_conn_helper_nonlinux_test.go b/sys_conn_helper_nonlinux_test.go new file mode 100644 index 00000000..29d42ad3 --- /dev/null +++ b/sys_conn_helper_nonlinux_test.go @@ -0,0 +1,7 @@ +//go:build !linux + +package quic + +import "errors" + +var errGSO = errors.New("fake GSO error") diff --git a/sys_conn_no_gso.go b/sys_conn_no_gso.go deleted file mode 100644 index 6f6a8c91..00000000 --- a/sys_conn_no_gso.go +++ /dev/null @@ -1,8 +0,0 @@ -//go:build darwin || freebsd - -package quic - -import "syscall" - -func maybeSetGSO(_ syscall.RawConn) bool { return false } -func appendUDPSegmentSizeMsg(_ []byte, _ uint16) []byte { return nil } diff --git a/sys_conn_oob.go b/sys_conn_oob.go index 84d5e7e6..aa69262b 100644 --- a/sys_conn_oob.go +++ b/sys_conn_oob.go @@ -5,7 +5,6 @@ package quic import ( "encoding/binary" "errors" - "fmt" "log" "net" "net/netip" @@ -128,10 +127,6 @@ func newConn(c OOBCapablePacketConn, supportsDF bool) (*oobConn, error) { bc = ipv4.NewPacketConn(c) } - // Try enabling GSO. - // This will only succeed on Linux, and only for kernels > 4.18. - supportsGSO := maybeSetGSO(rawConn) - msgs := make([]ipv4.Message, batchSize) for i := range msgs { // preallocate the [][]byte @@ -144,7 +139,6 @@ func newConn(c OOBCapablePacketConn, supportsDF bool) (*oobConn, error) { readPos: batchSize, } oobConn.cap.DF = supportsDF - oobConn.cap.GSO = supportsGSO for i := 0; i < batchSize; i++ { oobConn.messages[i].OOB = make([]byte, oobBufferSize) } @@ -231,17 +225,9 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) { } // WritePacket writes a new packet. -// If the connection supports GSO (and we activated GSO support before), -// it appends the UDP_SEGMENT size message to oob. -// Callers are advised to make sure that oob has a sufficient capacity, -// such that appending the UDP_SEGMENT size message doesn't cause an allocation. -func (c *oobConn) WritePacket(b []byte, packetSize uint16, addr net.Addr, oob []byte) (n int, err error) { - if c.cap.GSO { - oob = appendUDPSegmentSizeMsg(oob, packetSize) - } else if uint16(len(b)) != packetSize { - panic(fmt.Sprintf("inconsistent length. got: %d. expected %d", packetSize, len(b))) - } - n, _, err = c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr)) +// If the connection supports GSO, it's the caller's responsibility to append the right control mesage. +func (c *oobConn) WritePacket(b []byte, addr net.Addr, oob []byte) (int, error) { + n, _, err := c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr)) return n, err } diff --git a/transport.go b/transport.go index 42fafd49..ae44e3da 100644 --- a/transport.go +++ b/transport.go @@ -179,7 +179,7 @@ func (t *Transport) dial(ctx context.Context, addr net.Addr, hostname string, tl } tlsConf.ServerName = hostname } - return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, use0RTT) + return dial(ctx, newSendConn(t.conn, addr, packetInfo{}, utils.DefaultLogger), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, use0RTT) } func (t *Transport) init(allowZeroLengthConnIDs bool) error { @@ -195,7 +195,6 @@ func (t *Transport) init(allowZeroLengthConnIDs bool) error { return } } - t.conn = conn t.logger = utils.DefaultLogger // TODO: make this configurable t.conn = conn @@ -229,7 +228,7 @@ func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) { if err := t.init(false); err != nil { return 0, err } - return t.conn.WritePacket(b, uint16(len(b)), addr, nil) + return t.conn.WritePacket(b, addr, nil) } func (t *Transport) enqueueClosePacket(p closePacket) { @@ -247,7 +246,7 @@ func (t *Transport) runSendQueue() { case <-t.listening: return case p := <-t.closeQueue: - t.conn.WritePacket(p.payload, uint16(len(p.payload)), p.addr, p.info.OOB()) + t.conn.WritePacket(p.payload, p.addr, p.info.OOB()) case p := <-t.statelessResetQueue: t.sendStatelessReset(p) } @@ -408,7 +407,7 @@ func (t *Transport) sendStatelessReset(p receivedPacket) { rand.Read(data) data[0] = (data[0] & 0x7f) | 0x40 data = append(data, token[:]...) - if _, err := t.conn.WritePacket(data, uint16(len(data)), p.remoteAddr, p.info.OOB()); err != nil { + if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { t.logger.Debugf("Error sending Stateless Reset to %s: %s", p.remoteAddr, err) } }