From 600502ab06f693ca1f59e6925fdde0cb302e5a8c Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 8 May 2023 13:04:18 +0300 Subject: [PATCH] simplify connection handling when setting the receive buffer --- sys_conn_buffers.go | 29 ++++++++++++++++++++++++----- sys_conn_helper_linux.go | 16 ++-------------- sys_conn_helper_linux_test.go | 11 +++++++---- sys_conn_no_oob.go | 4 +--- sys_conn_oob.go | 15 ++------------- sys_conn_windows.go | 17 ++--------------- 6 files changed, 38 insertions(+), 54 deletions(-) diff --git a/sys_conn_buffers.go b/sys_conn_buffers.go index 2d60ab10..5d3b10b1 100644 --- a/sys_conn_buffers.go +++ b/sys_conn_buffers.go @@ -4,7 +4,7 @@ import ( "errors" "fmt" "net" - "sync" + "syscall" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" @@ -15,7 +15,26 @@ func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error { if !ok { return errors.New("connection doesn't allow setting of receive buffer size. Not a *net.UDPConn?") } - size, err := inspectReadBuffer(c) + + var syscallConn syscall.RawConn + if sc, ok := c.(interface { + SyscallConn() (syscall.RawConn, error) + }); ok { + var err error + syscallConn, err = sc.SyscallConn() + if err != nil { + syscallConn = nil + } + } + // The connection has a SetReadBuffer method, but we couldn't obtain a syscall.RawConn. + // This shouldn't happen for a net.UDPConn, but is possible if the connection just implements the + // net.PacketConn interface and the SetReadBuffer method. + // We have no way of checking if increasing the buffer size actually worked. + if syscallConn == nil { + return conn.SetReadBuffer(protocol.DesiredReceiveBufferSize) + } + + size, err := inspectReadBuffer(syscallConn) if err != nil { return fmt.Errorf("failed to determine receive buffer size: %w", err) } @@ -25,11 +44,11 @@ func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error { } // Ignore the error. We check if we succeeded by querying the buffer size afterward. _ = conn.SetReadBuffer(protocol.DesiredReceiveBufferSize) - newSize, err := inspectReadBuffer(c) + newSize, err := inspectReadBuffer(syscallConn) if newSize < protocol.DesiredReceiveBufferSize { // Try again with RCVBUFFORCE on Linux - _ = forceSetReceiveBuffer(c, protocol.DesiredReceiveBufferSize) - newSize, err = inspectReadBuffer(c) + _ = forceSetReceiveBuffer(syscallConn, protocol.DesiredReceiveBufferSize) + newSize, err = inspectReadBuffer(syscallConn) if err != nil { return fmt.Errorf("failed to determine receive buffer size: %w", err) } diff --git a/sys_conn_helper_linux.go b/sys_conn_helper_linux.go index 721b38ea..69e1a2d5 100644 --- a/sys_conn_helper_linux.go +++ b/sys_conn_helper_linux.go @@ -3,8 +3,6 @@ package quic import ( - "errors" - "fmt" "syscall" "golang.org/x/sys/unix" @@ -24,19 +22,9 @@ const ( const batchSize = 8 // needs to smaller than MaxUint8 (otherwise the type of oobConn.readPos has to be changed) -func forceSetReceiveBuffer(c interface{}, bytes int) error { - conn, ok := c.(interface { - SyscallConn() (syscall.RawConn, error) - }) - if !ok { - return errors.New("doesn't have a SyscallConn") - } - rawConn, err := conn.SyscallConn() - if err != nil { - return fmt.Errorf("couldn't get syscall.RawConn: %w", err) - } +func forceSetReceiveBuffer(c syscall.RawConn, bytes int) error { var serr error - if err := rawConn.Control(func(fd uintptr) { + if err := c.Control(func(fd uintptr) { serr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, bytes) }); err != nil { return err diff --git a/sys_conn_helper_linux_test.go b/sys_conn_helper_linux_test.go index c2e70ef1..dd095f95 100644 --- a/sys_conn_helper_linux_test.go +++ b/sys_conn_helper_linux_test.go @@ -21,15 +21,18 @@ var _ = Describe("Can change the receive buffer size", func() { c, err := net.ListenPacket("udp", "127.0.0.1:0") Expect(err).ToNot(HaveOccurred()) - forceSetReceiveBuffer(c, 256<<10) + defer c.Close() + syscallConn, err := c.(*net.UDPConn).SyscallConn() + Expect(err).ToNot(HaveOccurred()) + forceSetReceiveBuffer(syscallConn, 256<<10) - size, err := inspectReadBuffer(c) + size, err := inspectReadBuffer(syscallConn) Expect(err).ToNot(HaveOccurred()) // The kernel doubles this value (to allow space for bookkeeping overhead) Expect(size).To(Equal(512 << 10)) - forceSetReceiveBuffer(c, 512<<10) - size, err = inspectReadBuffer(c) + forceSetReceiveBuffer(syscallConn, 512<<10) + size, err = inspectReadBuffer(syscallConn) Expect(err).ToNot(HaveOccurred()) // The kernel doubles this value (to allow space for bookkeeping overhead) Expect(size).To(Equal(1024 << 10)) diff --git a/sys_conn_no_oob.go b/sys_conn_no_oob.go index 7ab5040a..e5189b18 100644 --- a/sys_conn_no_oob.go +++ b/sys_conn_no_oob.go @@ -8,8 +8,6 @@ func newConn(c net.PacketConn) (rawConn, error) { return &basicConn{PacketConn: c}, nil } -func inspectReadBuffer(interface{}) (int, error) { - return 0, nil -} +func inspectReadBuffer(any) (int, error) { return 0, nil } func (i *packetInfo) OOB() []byte { return nil } diff --git a/sys_conn_oob.go b/sys_conn_oob.go index 806dfb81..bc833efa 100644 --- a/sys_conn_oob.go +++ b/sys_conn_oob.go @@ -5,7 +5,6 @@ package quic import ( "encoding/binary" "errors" - "fmt" "net" "syscall" "time" @@ -32,20 +31,10 @@ type batchConn interface { ReadBatch(ms []ipv4.Message, flags int) (int, error) } -func inspectReadBuffer(c interface{}) (int, error) { - conn, ok := c.(interface { - SyscallConn() (syscall.RawConn, error) - }) - if !ok { - return 0, errors.New("doesn't have a SyscallConn") - } - rawConn, err := conn.SyscallConn() - if err != nil { - return 0, fmt.Errorf("couldn't get syscall.RawConn: %w", err) - } +func inspectReadBuffer(c syscall.RawConn) (int, error) { var size int var serr error - if err := rawConn.Control(func(fd uintptr) { + if err := c.Control(func(fd uintptr) { size, serr = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF) }); err != nil { return 0, err diff --git a/sys_conn_windows.go b/sys_conn_windows.go index b003fe94..205abd14 100644 --- a/sys_conn_windows.go +++ b/sys_conn_windows.go @@ -3,9 +3,6 @@ package quic import ( - "errors" - "fmt" - "net" "syscall" "golang.org/x/sys/windows" @@ -15,20 +12,10 @@ func newConn(c OOBCapablePacketConn) (rawConn, error) { return &basicConn{PacketConn: c}, nil } -func inspectReadBuffer(c net.PacketConn) (int, error) { - conn, ok := c.(interface { - SyscallConn() (syscall.RawConn, error) - }) - if !ok { - return 0, errors.New("doesn't have a SyscallConn") - } - rawConn, err := conn.SyscallConn() - if err != nil { - return 0, fmt.Errorf("couldn't get syscall.RawConn: %w", err) - } +func inspectReadBuffer(c syscall.RawConn) (int, error) { var size int var serr error - if err := rawConn.Control(func(fd uintptr) { + if err := c.Control(func(fd uintptr) { size, serr = windows.GetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_RCVBUF) }); err != nil { return 0, err