From 435444af7eb81b922a412ee80d4caddff26744de Mon Sep 17 00:00:00 2001 From: Glonee Date: Wed, 28 Jun 2023 02:29:30 +0800 Subject: [PATCH] add a context to Connection.ReceiveMessage (#3926) * add context to ReceiveMessage * add newlines --------- Co-authored-by: Marten Seemann --- connection.go | 4 ++-- datagram_queue.go | 5 ++++- datagram_queue_test.go | 23 +++++++++++++++++++---- integrationtests/self/datagram_test.go | 2 +- interface.go | 2 +- internal/mocks/quic/early_conn.go | 8 ++++---- mock_quic_conn_test.go | 8 ++++---- 7 files changed, 35 insertions(+), 17 deletions(-) diff --git a/connection.go b/connection.go index 433adeda..080f45ca 100644 --- a/connection.go +++ b/connection.go @@ -2302,11 +2302,11 @@ func (s *connection) SendMessage(p []byte) error { return s.datagramQueue.AddAndWait(f) } -func (s *connection) ReceiveMessage() ([]byte, error) { +func (s *connection) ReceiveMessage(ctx context.Context) ([]byte, error) { if !s.config.EnableDatagrams { return nil, errors.New("datagram support disabled") } - return s.datagramQueue.Receive() + return s.datagramQueue.Receive(ctx) } func (s *connection) LocalAddr() net.Addr { diff --git a/datagram_queue.go b/datagram_queue.go index 59c7d069..ca80d404 100644 --- a/datagram_queue.go +++ b/datagram_queue.go @@ -1,6 +1,7 @@ package quic import ( + "context" "sync" "github.com/quic-go/quic-go/internal/protocol" @@ -98,7 +99,7 @@ func (h *datagramQueue) HandleDatagramFrame(f *wire.DatagramFrame) { } // Receive gets a received DATAGRAM frame. -func (h *datagramQueue) Receive() ([]byte, error) { +func (h *datagramQueue) Receive(ctx context.Context) ([]byte, error) { for { h.rcvMx.Lock() if len(h.rcvQueue) > 0 { @@ -113,6 +114,8 @@ func (h *datagramQueue) Receive() ([]byte, error) { continue case <-h.closed: return nil, h.closeErr + case <-ctx.Done(): + return nil, ctx.Err() } } } diff --git a/datagram_queue_test.go b/datagram_queue_test.go index a18990df..de3f8f57 100644 --- a/datagram_queue_test.go +++ b/datagram_queue_test.go @@ -1,6 +1,7 @@ package quic import ( + "context" "errors" "github.com/quic-go/quic-go/internal/utils" @@ -81,10 +82,10 @@ var _ = Describe("Datagram Queue", func() { It("receives DATAGRAM frames", func() { queue.HandleDatagramFrame(&wire.DatagramFrame{Data: []byte("foo")}) queue.HandleDatagramFrame(&wire.DatagramFrame{Data: []byte("bar")}) - data, err := queue.Receive() + data, err := queue.Receive(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(data).To(Equal([]byte("foo"))) - data, err = queue.Receive() + data, err = queue.Receive(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(data).To(Equal([]byte("bar"))) }) @@ -93,7 +94,7 @@ var _ = Describe("Datagram Queue", func() { c := make(chan []byte, 1) go func() { defer GinkgoRecover() - data, err := queue.Receive() + data, err := queue.Receive(context.Background()) Expect(err).ToNot(HaveOccurred()) c <- data }() @@ -103,11 +104,25 @@ var _ = Describe("Datagram Queue", func() { Eventually(c).Should(Receive(Equal([]byte("foobar")))) }) + It("blocks until context is done", func() { + ctx, cancel := context.WithCancel(context.Background()) + errChan := make(chan error) + go func() { + defer GinkgoRecover() + _, err := queue.Receive(ctx) + errChan <- err + }() + + Consistently(errChan).ShouldNot(Receive()) + cancel() + Eventually(errChan).Should(Receive(Equal(context.Canceled))) + }) + It("closes", func() { errChan := make(chan error, 1) go func() { defer GinkgoRecover() - _, err := queue.Receive() + _, err := queue.Receive(context.Background()) errChan <- err }() diff --git a/integrationtests/self/datagram_test.go b/integrationtests/self/datagram_test.go index 65cb13fa..35d0718a 100644 --- a/integrationtests/self/datagram_test.go +++ b/integrationtests/self/datagram_test.go @@ -120,7 +120,7 @@ var _ = Describe("Datagram test", func() { for { // Close the connection if no message is received for 100 ms. timer := time.AfterFunc(scaleDuration(100*time.Millisecond), func() { conn.CloseWithError(0, "") }) - if _, err := conn.ReceiveMessage(); err != nil { + if _, err := conn.ReceiveMessage(context.Background()); err != nil { break } timer.Stop() diff --git a/interface.go b/interface.go index 8e6213bf..436e250e 100644 --- a/interface.go +++ b/interface.go @@ -185,7 +185,7 @@ type Connection interface { // SendMessage sends a message as a datagram, as specified in RFC 9221. SendMessage([]byte) error // ReceiveMessage gets a message received in a datagram, as specified in RFC 9221. - ReceiveMessage() ([]byte, error) + ReceiveMessage(context.Context) ([]byte, error) } // An EarlyConnection is a connection that is handshaking. diff --git a/internal/mocks/quic/early_conn.go b/internal/mocks/quic/early_conn.go index 174c70de..a573e06f 100644 --- a/internal/mocks/quic/early_conn.go +++ b/internal/mocks/quic/early_conn.go @@ -212,18 +212,18 @@ func (mr *MockEarlyConnectionMockRecorder) OpenUniStreamSync(arg0 interface{}) * } // ReceiveMessage mocks base method. -func (m *MockEarlyConnection) ReceiveMessage() ([]byte, error) { +func (m *MockEarlyConnection) ReceiveMessage(arg0 context.Context) ([]byte, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ReceiveMessage") + ret := m.ctrl.Call(m, "ReceiveMessage", arg0) ret0, _ := ret[0].([]byte) ret1, _ := ret[1].(error) return ret0, ret1 } // ReceiveMessage indicates an expected call of ReceiveMessage. -func (mr *MockEarlyConnectionMockRecorder) ReceiveMessage() *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) ReceiveMessage(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveMessage", reflect.TypeOf((*MockEarlyConnection)(nil).ReceiveMessage)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveMessage", reflect.TypeOf((*MockEarlyConnection)(nil).ReceiveMessage), arg0) } // RemoteAddr mocks base method. diff --git a/mock_quic_conn_test.go b/mock_quic_conn_test.go index bebc1c27..18932051 100644 --- a/mock_quic_conn_test.go +++ b/mock_quic_conn_test.go @@ -226,18 +226,18 @@ func (mr *MockQUICConnMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomock. } // ReceiveMessage mocks base method. -func (m *MockQUICConn) ReceiveMessage() ([]byte, error) { +func (m *MockQUICConn) ReceiveMessage(arg0 context.Context) ([]byte, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ReceiveMessage") + ret := m.ctrl.Call(m, "ReceiveMessage", arg0) ret0, _ := ret[0].([]byte) ret1, _ := ret[1].(error) return ret0, ret1 } // ReceiveMessage indicates an expected call of ReceiveMessage. -func (mr *MockQUICConnMockRecorder) ReceiveMessage() *gomock.Call { +func (mr *MockQUICConnMockRecorder) ReceiveMessage(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveMessage", reflect.TypeOf((*MockQUICConn)(nil).ReceiveMessage)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveMessage", reflect.TypeOf((*MockQUICConn)(nil).ReceiveMessage), arg0) } // RemoteAddr mocks base method.