add a context to Connection.ReceiveMessage (#3926)

* add context to ReceiveMessage

* add newlines

---------

Co-authored-by: Marten Seemann <martenseemann@gmail.com>
This commit is contained in:
Glonee 2023-06-28 02:29:30 +08:00 committed by GitHub
parent f3875147b9
commit 435444af7e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 35 additions and 17 deletions

View file

@ -2302,11 +2302,11 @@ func (s *connection) SendMessage(p []byte) error {
return s.datagramQueue.AddAndWait(f) return s.datagramQueue.AddAndWait(f)
} }
func (s *connection) ReceiveMessage() ([]byte, error) { func (s *connection) ReceiveMessage(ctx context.Context) ([]byte, error) {
if !s.config.EnableDatagrams { if !s.config.EnableDatagrams {
return nil, errors.New("datagram support disabled") return nil, errors.New("datagram support disabled")
} }
return s.datagramQueue.Receive() return s.datagramQueue.Receive(ctx)
} }
func (s *connection) LocalAddr() net.Addr { func (s *connection) LocalAddr() net.Addr {

View file

@ -1,6 +1,7 @@
package quic package quic
import ( import (
"context"
"sync" "sync"
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/protocol"
@ -98,7 +99,7 @@ func (h *datagramQueue) HandleDatagramFrame(f *wire.DatagramFrame) {
} }
// Receive gets a received DATAGRAM frame. // Receive gets a received DATAGRAM frame.
func (h *datagramQueue) Receive() ([]byte, error) { func (h *datagramQueue) Receive(ctx context.Context) ([]byte, error) {
for { for {
h.rcvMx.Lock() h.rcvMx.Lock()
if len(h.rcvQueue) > 0 { if len(h.rcvQueue) > 0 {
@ -113,6 +114,8 @@ func (h *datagramQueue) Receive() ([]byte, error) {
continue continue
case <-h.closed: case <-h.closed:
return nil, h.closeErr return nil, h.closeErr
case <-ctx.Done():
return nil, ctx.Err()
} }
} }
} }

View file

@ -1,6 +1,7 @@
package quic package quic
import ( import (
"context"
"errors" "errors"
"github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/utils"
@ -81,10 +82,10 @@ var _ = Describe("Datagram Queue", func() {
It("receives DATAGRAM frames", func() { It("receives DATAGRAM frames", func() {
queue.HandleDatagramFrame(&wire.DatagramFrame{Data: []byte("foo")}) queue.HandleDatagramFrame(&wire.DatagramFrame{Data: []byte("foo")})
queue.HandleDatagramFrame(&wire.DatagramFrame{Data: []byte("bar")}) queue.HandleDatagramFrame(&wire.DatagramFrame{Data: []byte("bar")})
data, err := queue.Receive() data, err := queue.Receive(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal([]byte("foo"))) Expect(data).To(Equal([]byte("foo")))
data, err = queue.Receive() data, err = queue.Receive(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal([]byte("bar"))) Expect(data).To(Equal([]byte("bar")))
}) })
@ -93,7 +94,7 @@ var _ = Describe("Datagram Queue", func() {
c := make(chan []byte, 1) c := make(chan []byte, 1)
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
data, err := queue.Receive() data, err := queue.Receive(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
c <- data c <- data
}() }()
@ -103,11 +104,25 @@ var _ = Describe("Datagram Queue", func() {
Eventually(c).Should(Receive(Equal([]byte("foobar")))) Eventually(c).Should(Receive(Equal([]byte("foobar"))))
}) })
It("blocks until context is done", func() {
ctx, cancel := context.WithCancel(context.Background())
errChan := make(chan error)
go func() {
defer GinkgoRecover()
_, err := queue.Receive(ctx)
errChan <- err
}()
Consistently(errChan).ShouldNot(Receive())
cancel()
Eventually(errChan).Should(Receive(Equal(context.Canceled)))
})
It("closes", func() { It("closes", func() {
errChan := make(chan error, 1) errChan := make(chan error, 1)
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
_, err := queue.Receive() _, err := queue.Receive(context.Background())
errChan <- err errChan <- err
}() }()

View file

@ -120,7 +120,7 @@ var _ = Describe("Datagram test", func() {
for { for {
// Close the connection if no message is received for 100 ms. // Close the connection if no message is received for 100 ms.
timer := time.AfterFunc(scaleDuration(100*time.Millisecond), func() { conn.CloseWithError(0, "") }) timer := time.AfterFunc(scaleDuration(100*time.Millisecond), func() { conn.CloseWithError(0, "") })
if _, err := conn.ReceiveMessage(); err != nil { if _, err := conn.ReceiveMessage(context.Background()); err != nil {
break break
} }
timer.Stop() timer.Stop()

View file

@ -185,7 +185,7 @@ type Connection interface {
// SendMessage sends a message as a datagram, as specified in RFC 9221. // SendMessage sends a message as a datagram, as specified in RFC 9221.
SendMessage([]byte) error SendMessage([]byte) error
// ReceiveMessage gets a message received in a datagram, as specified in RFC 9221. // ReceiveMessage gets a message received in a datagram, as specified in RFC 9221.
ReceiveMessage() ([]byte, error) ReceiveMessage(context.Context) ([]byte, error)
} }
// An EarlyConnection is a connection that is handshaking. // An EarlyConnection is a connection that is handshaking.

View file

@ -212,18 +212,18 @@ func (mr *MockEarlyConnectionMockRecorder) OpenUniStreamSync(arg0 interface{}) *
} }
// ReceiveMessage mocks base method. // ReceiveMessage mocks base method.
func (m *MockEarlyConnection) ReceiveMessage() ([]byte, error) { func (m *MockEarlyConnection) ReceiveMessage(arg0 context.Context) ([]byte, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ReceiveMessage") ret := m.ctrl.Call(m, "ReceiveMessage", arg0)
ret0, _ := ret[0].([]byte) ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// ReceiveMessage indicates an expected call of ReceiveMessage. // ReceiveMessage indicates an expected call of ReceiveMessage.
func (mr *MockEarlyConnectionMockRecorder) ReceiveMessage() *gomock.Call { func (mr *MockEarlyConnectionMockRecorder) ReceiveMessage(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveMessage", reflect.TypeOf((*MockEarlyConnection)(nil).ReceiveMessage)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveMessage", reflect.TypeOf((*MockEarlyConnection)(nil).ReceiveMessage), arg0)
} }
// RemoteAddr mocks base method. // RemoteAddr mocks base method.

View file

@ -226,18 +226,18 @@ func (mr *MockQUICConnMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomock.
} }
// ReceiveMessage mocks base method. // ReceiveMessage mocks base method.
func (m *MockQUICConn) ReceiveMessage() ([]byte, error) { func (m *MockQUICConn) ReceiveMessage(arg0 context.Context) ([]byte, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ReceiveMessage") ret := m.ctrl.Call(m, "ReceiveMessage", arg0)
ret0, _ := ret[0].([]byte) ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// ReceiveMessage indicates an expected call of ReceiveMessage. // ReceiveMessage indicates an expected call of ReceiveMessage.
func (mr *MockQUICConnMockRecorder) ReceiveMessage() *gomock.Call { func (mr *MockQUICConnMockRecorder) ReceiveMessage(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveMessage", reflect.TypeOf((*MockQUICConn)(nil).ReceiveMessage)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveMessage", reflect.TypeOf((*MockQUICConn)(nil).ReceiveMessage), arg0)
} }
// RemoteAddr mocks base method. // RemoteAddr mocks base method.