mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
add a context to Connection.ReceiveMessage (#3926)
* add context to ReceiveMessage * add newlines --------- Co-authored-by: Marten Seemann <martenseemann@gmail.com>
This commit is contained in:
parent
f3875147b9
commit
435444af7e
7 changed files with 35 additions and 17 deletions
|
@ -2302,11 +2302,11 @@ func (s *connection) SendMessage(p []byte) error {
|
|||
return s.datagramQueue.AddAndWait(f)
|
||||
}
|
||||
|
||||
func (s *connection) ReceiveMessage() ([]byte, error) {
|
||||
func (s *connection) ReceiveMessage(ctx context.Context) ([]byte, error) {
|
||||
if !s.config.EnableDatagrams {
|
||||
return nil, errors.New("datagram support disabled")
|
||||
}
|
||||
return s.datagramQueue.Receive()
|
||||
return s.datagramQueue.Receive(ctx)
|
||||
}
|
||||
|
||||
func (s *connection) LocalAddr() net.Addr {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
|
@ -98,7 +99,7 @@ func (h *datagramQueue) HandleDatagramFrame(f *wire.DatagramFrame) {
|
|||
}
|
||||
|
||||
// Receive gets a received DATAGRAM frame.
|
||||
func (h *datagramQueue) Receive() ([]byte, error) {
|
||||
func (h *datagramQueue) Receive(ctx context.Context) ([]byte, error) {
|
||||
for {
|
||||
h.rcvMx.Lock()
|
||||
if len(h.rcvQueue) > 0 {
|
||||
|
@ -113,6 +114,8 @@ func (h *datagramQueue) Receive() ([]byte, error) {
|
|||
continue
|
||||
case <-h.closed:
|
||||
return nil, h.closeErr
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
|
@ -81,10 +82,10 @@ var _ = Describe("Datagram Queue", func() {
|
|||
It("receives DATAGRAM frames", func() {
|
||||
queue.HandleDatagramFrame(&wire.DatagramFrame{Data: []byte("foo")})
|
||||
queue.HandleDatagramFrame(&wire.DatagramFrame{Data: []byte("bar")})
|
||||
data, err := queue.Receive()
|
||||
data, err := queue.Receive(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal([]byte("foo")))
|
||||
data, err = queue.Receive()
|
||||
data, err = queue.Receive(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal([]byte("bar")))
|
||||
})
|
||||
|
@ -93,7 +94,7 @@ var _ = Describe("Datagram Queue", func() {
|
|||
c := make(chan []byte, 1)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
data, err := queue.Receive()
|
||||
data, err := queue.Receive(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
c <- data
|
||||
}()
|
||||
|
@ -103,11 +104,25 @@ var _ = Describe("Datagram Queue", func() {
|
|||
Eventually(c).Should(Receive(Equal([]byte("foobar"))))
|
||||
})
|
||||
|
||||
It("blocks until context is done", func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
errChan := make(chan error)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := queue.Receive(ctx)
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
Consistently(errChan).ShouldNot(Receive())
|
||||
cancel()
|
||||
Eventually(errChan).Should(Receive(Equal(context.Canceled)))
|
||||
})
|
||||
|
||||
It("closes", func() {
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := queue.Receive()
|
||||
_, err := queue.Receive(context.Background())
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
|
|
|
@ -120,7 +120,7 @@ var _ = Describe("Datagram test", func() {
|
|||
for {
|
||||
// Close the connection if no message is received for 100 ms.
|
||||
timer := time.AfterFunc(scaleDuration(100*time.Millisecond), func() { conn.CloseWithError(0, "") })
|
||||
if _, err := conn.ReceiveMessage(); err != nil {
|
||||
if _, err := conn.ReceiveMessage(context.Background()); err != nil {
|
||||
break
|
||||
}
|
||||
timer.Stop()
|
||||
|
|
|
@ -185,7 +185,7 @@ type Connection interface {
|
|||
// SendMessage sends a message as a datagram, as specified in RFC 9221.
|
||||
SendMessage([]byte) error
|
||||
// ReceiveMessage gets a message received in a datagram, as specified in RFC 9221.
|
||||
ReceiveMessage() ([]byte, error)
|
||||
ReceiveMessage(context.Context) ([]byte, error)
|
||||
}
|
||||
|
||||
// An EarlyConnection is a connection that is handshaking.
|
||||
|
|
|
@ -212,18 +212,18 @@ func (mr *MockEarlyConnectionMockRecorder) OpenUniStreamSync(arg0 interface{}) *
|
|||
}
|
||||
|
||||
// ReceiveMessage mocks base method.
|
||||
func (m *MockEarlyConnection) ReceiveMessage() ([]byte, error) {
|
||||
func (m *MockEarlyConnection) ReceiveMessage(arg0 context.Context) ([]byte, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ReceiveMessage")
|
||||
ret := m.ctrl.Call(m, "ReceiveMessage", arg0)
|
||||
ret0, _ := ret[0].([]byte)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ReceiveMessage indicates an expected call of ReceiveMessage.
|
||||
func (mr *MockEarlyConnectionMockRecorder) ReceiveMessage() *gomock.Call {
|
||||
func (mr *MockEarlyConnectionMockRecorder) ReceiveMessage(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveMessage", reflect.TypeOf((*MockEarlyConnection)(nil).ReceiveMessage))
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveMessage", reflect.TypeOf((*MockEarlyConnection)(nil).ReceiveMessage), arg0)
|
||||
}
|
||||
|
||||
// RemoteAddr mocks base method.
|
||||
|
|
|
@ -226,18 +226,18 @@ func (mr *MockQUICConnMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomock.
|
|||
}
|
||||
|
||||
// ReceiveMessage mocks base method.
|
||||
func (m *MockQUICConn) ReceiveMessage() ([]byte, error) {
|
||||
func (m *MockQUICConn) ReceiveMessage(arg0 context.Context) ([]byte, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ReceiveMessage")
|
||||
ret := m.ctrl.Call(m, "ReceiveMessage", arg0)
|
||||
ret0, _ := ret[0].([]byte)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ReceiveMessage indicates an expected call of ReceiveMessage.
|
||||
func (mr *MockQUICConnMockRecorder) ReceiveMessage() *gomock.Call {
|
||||
func (mr *MockQUICConnMockRecorder) ReceiveMessage(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveMessage", reflect.TypeOf((*MockQUICConn)(nil).ReceiveMessage))
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveMessage", reflect.TypeOf((*MockQUICConn)(nil).ReceiveMessage), arg0)
|
||||
}
|
||||
|
||||
// RemoteAddr mocks base method.
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue