diff --git a/datagram_queue.go b/datagram_queue.go index eb32fe7c..92b5c3b0 100644 --- a/datagram_queue.go +++ b/datagram_queue.go @@ -1,50 +1,76 @@ package quic import ( - "sync" - + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" ) type datagramQueue struct { - mutex sync.Mutex - queue chan *wire.DatagramFrame + sendQueue chan *wire.DatagramFrame + rcvQueue chan []byte closeErr error closed chan struct{} hasData func() + + logger utils.Logger } -func newDatagramQueue(hasData func()) *datagramQueue { +func newDatagramQueue(hasData func(), logger utils.Logger) *datagramQueue { return &datagramQueue{ - queue: make(chan *wire.DatagramFrame), - hasData: hasData, - closed: make(chan struct{}), + hasData: hasData, + sendQueue: make(chan *wire.DatagramFrame), + rcvQueue: make(chan []byte, protocol.DatagramRcvQueueLen), + closed: make(chan struct{}), + logger: logger, } } -// AddAndWait queues a new DATAGRAM frame. +// AddAndWait queues a new DATAGRAM frame for sending. // It blocks until the frame has been dequeued. func (h *datagramQueue) AddAndWait(f *wire.DatagramFrame) error { h.hasData() select { - case h.queue <- f: + case h.sendQueue <- f: return nil case <-h.closed: return h.closeErr } } +// Get dequeues a DATAGRAM frame for sending. func (h *datagramQueue) Get() *wire.DatagramFrame { select { - case f := <-h.queue: + case f := <-h.sendQueue: return f default: return nil } } +// HandleDatagramFrame handles a received DATAGRAM frame. +func (h *datagramQueue) HandleDatagramFrame(f *wire.DatagramFrame) { + data := make([]byte, len(f.Data)) + copy(data, f.Data) + select { + case h.rcvQueue <- data: + default: + h.logger.Debugf("Discarding DATAGRAM frame (%d bytes payload)", len(f.Data)) + } +} + +// Receive gets a received DATAGRAM frame. +func (h *datagramQueue) Receive() ([]byte, error) { + select { + case data := <-h.rcvQueue: + return data, nil + case <-h.closed: + return nil, h.closeErr + } +} + func (h *datagramQueue) CloseWithError(e error) { h.closeErr = e close(h.closed) diff --git a/datagram_queue_test.go b/datagram_queue_test.go index 36e35a7e..0ff7b96e 100644 --- a/datagram_queue_test.go +++ b/datagram_queue_test.go @@ -3,7 +3,9 @@ package quic import ( "errors" + "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -16,39 +18,81 @@ var _ = Describe("Datagram Queue", func() { queued = make(chan struct{}, 100) queue = newDatagramQueue(func() { queued <- struct{}{} + }, utils.DefaultLogger) + }) + + Context("sending", func() { + It("returns nil when there's no datagram to send", func() { + Expect(queue.Get()).To(BeNil()) + }) + + It("queues a datagram", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + Expect(queue.AddAndWait(&wire.DatagramFrame{Data: []byte("foobar")})).To(Succeed()) + }() + + Eventually(queued).Should(HaveLen(1)) + Consistently(done).ShouldNot(BeClosed()) + f := queue.Get() + Expect(f).ToNot(BeNil()) + Expect(f.Data).To(Equal([]byte("foobar"))) + Eventually(done).Should(BeClosed()) + Expect(queue.Get()).To(BeNil()) + }) + + It("closes", func() { + errChan := make(chan error, 1) + go func() { + defer GinkgoRecover() + errChan <- queue.AddAndWait(&wire.DatagramFrame{Data: []byte("foobar")}) + }() + + Consistently(errChan).ShouldNot(Receive()) + queue.CloseWithError(errors.New("test error")) + Eventually(errChan).Should(Receive(MatchError("test error"))) }) }) - It("returns nil when there's no datagram to send", func() { - Expect(queue.Get()).To(BeNil()) - }) + Context("receiving", func() { + It("receives DATAGRAM frames", func() { + queue.HandleDatagramFrame(&wire.DatagramFrame{Data: []byte("foo")}) + queue.HandleDatagramFrame(&wire.DatagramFrame{Data: []byte("bar")}) + data, err := queue.Receive() + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal([]byte("foo"))) + data, err = queue.Receive() + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal([]byte("bar"))) + }) - It("queues a datagram", func() { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - Expect(queue.AddAndWait(&wire.DatagramFrame{Data: []byte("foobar")})).To(Succeed()) - }() + It("blocks until a frame is received", func() { + c := make(chan []byte, 1) + go func() { + defer GinkgoRecover() + data, err := queue.Receive() + Expect(err).ToNot(HaveOccurred()) + c <- data + }() - Eventually(queued).Should(HaveLen(1)) - Consistently(done).ShouldNot(BeClosed()) - f := queue.Get() - Expect(f).ToNot(BeNil()) - Expect(f.Data).To(Equal([]byte("foobar"))) - Eventually(done).Should(BeClosed()) - Expect(queue.Get()).To(BeNil()) - }) + Consistently(c).ShouldNot(Receive()) + queue.HandleDatagramFrame(&wire.DatagramFrame{Data: []byte("foobar")}) + Eventually(c).Should(Receive(Equal([]byte("foobar")))) + }) - It("closes", func() { - errChan := make(chan error, 1) - go func() { - defer GinkgoRecover() - errChan <- queue.AddAndWait(&wire.DatagramFrame{Data: []byte("foobar")}) - }() + It("closes", func() { + errChan := make(chan error, 1) + go func() { + defer GinkgoRecover() + _, err := queue.Receive() + errChan <- err + }() - Consistently(errChan).ShouldNot(Receive()) - queue.CloseWithError(errors.New("test error")) - Eventually(errChan).Should(Receive(MatchError("test error"))) + Consistently(errChan).ShouldNot(Receive()) + queue.CloseWithError(errors.New("test error")) + Eventually(errChan).Should(Receive(MatchError("test error"))) + }) }) }) diff --git a/interface.go b/interface.go index b58c4111..c8e2fd41 100644 --- a/interface.go +++ b/interface.go @@ -191,6 +191,9 @@ type Session interface { // SendMessage sends a message as a datagram. // See https://datatracker.ietf.org/doc/draft-pauly-quic-datagram/. SendMessage([]byte) error + // ReceiveMessage gets a message received in a datagram. + // See https://datatracker.ietf.org/doc/draft-pauly-quic-datagram/. + ReceiveMessage() ([]byte, error) } // An EarlySession is a session that is handshaking. diff --git a/internal/mocks/quic/early_session.go b/internal/mocks/quic/early_session.go index 1e7fe4ce..0c81bead 100644 --- a/internal/mocks/quic/early_session.go +++ b/internal/mocks/quic/early_session.go @@ -197,6 +197,21 @@ func (mr *MockEarlySessionMockRecorder) OpenUniStreamSync(arg0 interface{}) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockEarlySession)(nil).OpenUniStreamSync), arg0) } +// ReceiveMessage mocks base method +func (m *MockEarlySession) ReceiveMessage() ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReceiveMessage") + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReceiveMessage indicates an expected call of ReceiveMessage +func (mr *MockEarlySessionMockRecorder) ReceiveMessage() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveMessage", reflect.TypeOf((*MockEarlySession)(nil).ReceiveMessage)) +} + // RemoteAddr mocks base method func (m *MockEarlySession) RemoteAddr() net.Addr { m.ctrl.T.Helper() diff --git a/internal/protocol/params.go b/internal/protocol/params.go index cafcdcae..8aa143e3 100644 --- a/internal/protocol/params.go +++ b/internal/protocol/params.go @@ -130,6 +130,10 @@ const MaxAckFrameSize ByteCount = 1000 // The size is chosen such that a DATAGRAM frame fits into a QUIC packet. const MaxDatagramFrameSize ByteCount = 1200 +// DatagramRcvQueueLen is the length of the receive queue for DATAGRAM frames. +// See https://datatracker.ietf.org/doc/draft-pauly-quic-datagram/. +const DatagramRcvQueueLen = 128 + // MaxNumAckRanges is the maximum number of ACK ranges that we send in an ACK frame. // It also serves as a limit for the packet history. // If at any point we keep track of more ranges, old ranges are discarded. diff --git a/mock_quic_session_test.go b/mock_quic_session_test.go index 06a7f340..d8b5ac0c 100644 --- a/mock_quic_session_test.go +++ b/mock_quic_session_test.go @@ -210,6 +210,21 @@ func (mr *MockQuicSessionMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomo return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockQuicSession)(nil).OpenUniStreamSync), arg0) } +// ReceiveMessage mocks base method +func (m *MockQuicSession) ReceiveMessage() ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReceiveMessage") + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReceiveMessage indicates an expected call of ReceiveMessage +func (mr *MockQuicSessionMockRecorder) ReceiveMessage() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveMessage", reflect.TypeOf((*MockQuicSession)(nil).ReceiveMessage)) +} + // RemoteAddr mocks base method func (m *MockQuicSession) RemoteAddr() net.Addr { m.ctrl.T.Helper() diff --git a/packet_packer_test.go b/packet_packer_test.go index 1b01722c..9fbe557b 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -13,6 +13,7 @@ import ( mockackhandler "github.com/lucas-clemente/quic-go/internal/mocks/ackhandler" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/qerr" + "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" "github.com/golang/mock/gomock" @@ -90,7 +91,7 @@ var _ = Describe("Packet packer", func() { ackFramer = NewMockAckFrameSource(mockCtrl) sealingManager = NewMockSealingManager(mockCtrl) pnManager = mockackhandler.NewMockSentPacketHandler(mockCtrl) - datagramQueue = newDatagramQueue(func() {}) + datagramQueue = newDatagramQueue(func() {}, utils.DefaultLogger) packer = newPacketPacker( protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, diff --git a/session.go b/session.go index ad3b623f..c62a2795 100644 --- a/session.go +++ b/session.go @@ -513,7 +513,7 @@ func (s *session) preSetup() { s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.connFlowController, s.framer.QueueControlFrame) if s.config.EnableDatagrams { - s.datagramQueue = newDatagramQueue(s.scheduleSending) + s.datagramQueue = newDatagramQueue(s.scheduleSending, s.logger) } } @@ -1119,7 +1119,7 @@ func (s *session) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel, d case *wire.HandshakeDoneFrame: err = s.handleHandshakeDoneFrame() case *wire.DatagramFrame: - // TODO: handle DATRAGRAM frames + err = s.handleDatagramFrame(frame) default: err = fmt.Errorf("unexpected frame type: %s", reflect.ValueOf(&frame).Elem().Type().Name()) } @@ -1258,6 +1258,14 @@ func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.Encrypt return s.cryptoStreamHandler.SetLargest1RTTAcked(frame.LargestAcked()) } +func (s *session) handleDatagramFrame(f *wire.DatagramFrame) error { + if f.Length(s.version) > protocol.MaxDatagramFrameSize { + return qerr.NewError(qerr.ProtocolViolation, "DATAGRAM frame too large") + } + s.datagramQueue.HandleDatagramFrame(f) + return nil +} + // closeLocal closes the session and send a CONNECTION_CLOSE containing the error func (s *session) closeLocal(e error) { s.closeOnce.Do(func() { @@ -1755,6 +1763,10 @@ func (s *session) SendMessage(p []byte) error { return nil } +func (s *session) ReceiveMessage() ([]byte, error) { + return s.datagramQueue.Receive() +} + func (s *session) LocalAddr() net.Addr { return s.conn.LocalAddr() }