diff --git a/client_test.go b/client_test.go index 68bcc92b..23fd3301 100644 --- a/client_test.go +++ b/client_test.go @@ -465,6 +465,7 @@ var _ = Describe("Client", func() { ConnectionIDLength: 13, StatelessResetKey: []byte("foobar"), TokenStore: tokenStore, + EnableDatagrams: true, } c := populateClientConfig(config, false) Expect(c.HandshakeIdleTimeout).To(Equal(1337 * time.Minute)) @@ -474,6 +475,7 @@ var _ = Describe("Client", func() { Expect(c.ConnectionIDLength).To(Equal(13)) Expect(c.StatelessResetKey).To(Equal([]byte("foobar"))) Expect(c.TokenStore).To(Equal(tokenStore)) + Expect(c.EnableDatagrams).To(BeTrue()) }) It("errors when the Config contains an invalid version", func() { diff --git a/config.go b/config.go index c8479c5c..05f8248e 100644 --- a/config.go +++ b/config.go @@ -105,6 +105,7 @@ func populateConfig(config *Config) *Config { ConnectionIDLength: config.ConnectionIDLength, StatelessResetKey: config.StatelessResetKey, TokenStore: config.TokenStore, + EnableDatagrams: config.EnableDatagrams, Tracer: config.Tracer, } } diff --git a/config_test.go b/config_test.go index eb9a04cb..024e9eb2 100644 --- a/config_test.go +++ b/config_test.go @@ -69,6 +69,8 @@ var _ = Describe("Config", func() { f.Set(reflect.ValueOf([]byte{1, 2, 3, 4})) case "KeepAlive": f.Set(reflect.ValueOf(true)) + case "EnableDatagrams": + f.Set(reflect.ValueOf(true)) case "Tracer": f.Set(reflect.ValueOf(mocklogging.NewMockTracer(mockCtrl))) default: diff --git a/datagram_queue.go b/datagram_queue.go new file mode 100644 index 00000000..92b5c3b0 --- /dev/null +++ b/datagram_queue.go @@ -0,0 +1,77 @@ +package quic + +import ( + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type datagramQueue struct { + sendQueue chan *wire.DatagramFrame + rcvQueue chan []byte + + closeErr error + closed chan struct{} + + hasData func() + + logger utils.Logger +} + +func newDatagramQueue(hasData func(), logger utils.Logger) *datagramQueue { + return &datagramQueue{ + hasData: hasData, + sendQueue: make(chan *wire.DatagramFrame), + rcvQueue: make(chan []byte, protocol.DatagramRcvQueueLen), + closed: make(chan struct{}), + logger: logger, + } +} + +// AddAndWait queues a new DATAGRAM frame for sending. +// It blocks until the frame has been dequeued. +func (h *datagramQueue) AddAndWait(f *wire.DatagramFrame) error { + h.hasData() + select { + case h.sendQueue <- f: + return nil + case <-h.closed: + return h.closeErr + } +} + +// Get dequeues a DATAGRAM frame for sending. +func (h *datagramQueue) Get() *wire.DatagramFrame { + select { + case f := <-h.sendQueue: + return f + default: + return nil + } +} + +// HandleDatagramFrame handles a received DATAGRAM frame. +func (h *datagramQueue) HandleDatagramFrame(f *wire.DatagramFrame) { + data := make([]byte, len(f.Data)) + copy(data, f.Data) + select { + case h.rcvQueue <- data: + default: + h.logger.Debugf("Discarding DATAGRAM frame (%d bytes payload)", len(f.Data)) + } +} + +// Receive gets a received DATAGRAM frame. +func (h *datagramQueue) Receive() ([]byte, error) { + select { + case data := <-h.rcvQueue: + return data, nil + case <-h.closed: + return nil, h.closeErr + } +} + +func (h *datagramQueue) CloseWithError(e error) { + h.closeErr = e + close(h.closed) +} diff --git a/datagram_queue_test.go b/datagram_queue_test.go new file mode 100644 index 00000000..0ff7b96e --- /dev/null +++ b/datagram_queue_test.go @@ -0,0 +1,98 @@ +package quic + +import ( + "errors" + + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Datagram Queue", func() { + var queue *datagramQueue + var queued chan struct{} + + BeforeEach(func() { + queued = make(chan struct{}, 100) + queue = newDatagramQueue(func() { + queued <- struct{}{} + }, utils.DefaultLogger) + }) + + Context("sending", func() { + It("returns nil when there's no datagram to send", func() { + Expect(queue.Get()).To(BeNil()) + }) + + It("queues a datagram", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + Expect(queue.AddAndWait(&wire.DatagramFrame{Data: []byte("foobar")})).To(Succeed()) + }() + + Eventually(queued).Should(HaveLen(1)) + Consistently(done).ShouldNot(BeClosed()) + f := queue.Get() + Expect(f).ToNot(BeNil()) + Expect(f.Data).To(Equal([]byte("foobar"))) + Eventually(done).Should(BeClosed()) + Expect(queue.Get()).To(BeNil()) + }) + + It("closes", func() { + errChan := make(chan error, 1) + go func() { + defer GinkgoRecover() + errChan <- queue.AddAndWait(&wire.DatagramFrame{Data: []byte("foobar")}) + }() + + Consistently(errChan).ShouldNot(Receive()) + queue.CloseWithError(errors.New("test error")) + Eventually(errChan).Should(Receive(MatchError("test error"))) + }) + }) + + Context("receiving", func() { + It("receives DATAGRAM frames", func() { + queue.HandleDatagramFrame(&wire.DatagramFrame{Data: []byte("foo")}) + queue.HandleDatagramFrame(&wire.DatagramFrame{Data: []byte("bar")}) + data, err := queue.Receive() + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal([]byte("foo"))) + data, err = queue.Receive() + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal([]byte("bar"))) + }) + + It("blocks until a frame is received", func() { + c := make(chan []byte, 1) + go func() { + defer GinkgoRecover() + data, err := queue.Receive() + Expect(err).ToNot(HaveOccurred()) + c <- data + }() + + Consistently(c).ShouldNot(Receive()) + queue.HandleDatagramFrame(&wire.DatagramFrame{Data: []byte("foobar")}) + Eventually(c).Should(Receive(Equal([]byte("foobar")))) + }) + + It("closes", func() { + errChan := make(chan error, 1) + go func() { + defer GinkgoRecover() + _, err := queue.Receive() + errChan <- err + }() + + Consistently(errChan).ShouldNot(Receive()) + queue.CloseWithError(errors.New("test error")) + Eventually(errChan).Should(Receive(MatchError("test error"))) + }) + }) +}) diff --git a/fuzzing/frames/fuzz.go b/fuzzing/frames/fuzz.go index 1df2c81f..cd0409bc 100644 --- a/fuzzing/frames/fuzz.go +++ b/fuzzing/frames/fuzz.go @@ -33,7 +33,7 @@ func Fuzz(data []byte) int { encLevel := toEncLevel(data[0]) data = data[PrefixLen:] - parser := wire.NewFrameParser(version) + parser := wire.NewFrameParser(true, version) parser.SetAckDelayExponent(protocol.DefaultAckDelayExponent) r := bytes.NewReader(data) diff --git a/http3/client.go b/http3/client.go index 9bacd679..d77efc04 100644 --- a/http3/client.go +++ b/http3/client.go @@ -250,7 +250,7 @@ func (c *client) doRequest( return nil, newConnError(errorGeneralProtocolError, err) } - connState := qtls.ToTLSConnectionState(c.session.ConnectionState()) + connState := qtls.ToTLSConnectionState(c.session.ConnectionState().TLS) res := &http.Response{ Proto: "HTTP/3", ProtoMajor: 3, diff --git a/http3/client_test.go b/http3/client_test.go index c107a9cb..29b6909e 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -16,7 +16,6 @@ import ( mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qtls" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/marten-seemann/qpack" @@ -240,7 +239,7 @@ var _ = Describe("Client", func() { gomock.InOrder( sess.EXPECT().HandshakeComplete().Return(handshakeCtx), sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), - sess.EXPECT().ConnectionState().Return(qtls.ConnectionState{}), + sess.EXPECT().ConnectionState().Return(quic.ConnectionState{}), ) str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) str.EXPECT().Close() @@ -410,7 +409,7 @@ var _ = Describe("Client", func() { req := request.WithContext(ctx) sess.EXPECT().HandshakeComplete().Return(handshakeCtx) sess.EXPECT().OpenStreamSync(ctx).Return(str, nil) - sess.EXPECT().ConnectionState().Return(qtls.ConnectionState{}) + sess.EXPECT().ConnectionState().Return(quic.ConnectionState{}) buf := &bytes.Buffer{} str.EXPECT().Close().MaxTimes(1) @@ -473,7 +472,7 @@ var _ = Describe("Client", func() { It("decompresses the response", func() { sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) - sess.EXPECT().ConnectionState().Return(qtls.ConnectionState{}) + sess.EXPECT().ConnectionState().Return(quic.ConnectionState{}) buf := &bytes.Buffer{} rw := newResponseWriter(buf, utils.DefaultLogger) rw.Header().Set("Content-Encoding", "gzip") @@ -499,7 +498,7 @@ var _ = Describe("Client", func() { It("only decompresses the response if the response contains the right content-encoding header", func() { sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) - sess.EXPECT().ConnectionState().Return(qtls.ConnectionState{}) + sess.EXPECT().ConnectionState().Return(quic.ConnectionState{}) buf := &bytes.Buffer{} rw := newResponseWriter(buf, utils.DefaultLogger) rw.Write([]byte("not gzipped")) diff --git a/integrationtests/self/datagram_test.go b/integrationtests/self/datagram_test.go new file mode 100644 index 00000000..371f4c65 --- /dev/null +++ b/integrationtests/self/datagram_test.go @@ -0,0 +1,141 @@ +package self_test + +import ( + "context" + "encoding/binary" + "fmt" + mrand "math/rand" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/lucas-clemente/quic-go" + quicproxy "github.com/lucas-clemente/quic-go/integrationtests/tools/proxy" + "github.com/lucas-clemente/quic-go/internal/protocol" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Datagram test", func() { + for _, v := range protocol.SupportedVersions { + version := v + + Context(fmt.Sprintf("with QUIC version %s", version), func() { + const num = 100 + + var ( + proxy *quicproxy.QuicProxy + serverConn, clientConn *net.UDPConn + dropped, total int32 + ) + + startServerAndProxy := func() { + addr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + serverConn, err = net.ListenUDP("udp", addr) + Expect(err).ToNot(HaveOccurred()) + ln, err := quic.Listen( + serverConn, + getTLSConfig(), + getQuicConfig(&quic.Config{ + EnableDatagrams: true, + Versions: []protocol.VersionNumber{version}, + }), + ) + Expect(err).ToNot(HaveOccurred()) + go func() { + defer GinkgoRecover() + sess, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(sess.ConnectionState().SupportsDatagrams).To(BeTrue()) + + var wg sync.WaitGroup + wg.Add(num) + for i := 0; i < num; i++ { + go func(i int) { + defer GinkgoRecover() + defer wg.Done() + b := make([]byte, 8) + binary.BigEndian.PutUint64(b, uint64(i)) + Expect(sess.SendMessage(b)).To(Succeed()) + }(i) + } + wg.Wait() + }() + serverPort := ln.Addr().(*net.UDPAddr).Port + proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), + // drop 10% of Short Header packets sent from the server + DropPacket: func(dir quicproxy.Direction, packet []byte) bool { + if dir == quicproxy.DirectionIncoming { + return false + } + // don't drop Long Header packets + if packet[0]&0x80 == 1 { + return false + } + drop := mrand.Int()%10 == 0 + if drop { + atomic.AddInt32(&dropped, 1) + } + atomic.AddInt32(&total, 1) + return drop + }, + }) + Expect(err).ToNot(HaveOccurred()) + } + + BeforeEach(func() { + addr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + clientConn, err = net.ListenUDP("udp", addr) + Expect(err).ToNot(HaveOccurred()) + }) + + AfterEach(func() { + Expect(proxy.Close()).To(Succeed()) + }) + + It("sends datagrams", func() { + startServerAndProxy() + raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort())) + Expect(err).ToNot(HaveOccurred()) + sess, err := quic.Dial( + clientConn, + raddr, + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + getTLSClientConfig(), + getQuicConfig(&quic.Config{ + EnableDatagrams: true, + Versions: []protocol.VersionNumber{version}, + }), + ) + Expect(err).ToNot(HaveOccurred()) + Expect(sess.ConnectionState().SupportsDatagrams).To(BeTrue()) + var counter int + for { + // Close the session if no message is received for 100 ms. + timer := time.AfterFunc(scaleDuration(100*time.Millisecond), func() { + sess.CloseWithError(0, "") + }) + if _, err := sess.ReceiveMessage(); err != nil { + break + } + timer.Stop() + counter++ + } + + numDropped := int(atomic.LoadInt32(&dropped)) + expVal := num - numDropped + fmt.Fprintf(GinkgoWriter, "Dropped %d out of %d packets.\n", numDropped, atomic.LoadInt32(&total)) + fmt.Fprintf(GinkgoWriter, "Received %d out of %d sent datagrams.\n", counter, num) + Expect(counter).To(And( + BeNumerically(">", expVal*9/10), + BeNumerically("<", num), + )) + }) + }) + } +}) diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index d1c5a969..cd387044 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -170,7 +170,7 @@ var _ = Describe("Handshake tests", func() { data, err := ioutil.ReadAll(str) Expect(err).ToNot(HaveOccurred()) Expect(data).To(Equal(PRData)) - Expect(sess.ConnectionState().CipherSuite).To(Equal(suiteID)) + Expect(sess.ConnectionState().TLS.CipherSuite).To(Equal(suiteID)) Expect(sess.CloseWithError(0, "")).To(Succeed()) }) } @@ -369,7 +369,7 @@ var _ = Describe("Handshake tests", func() { sess, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) cs := sess.ConnectionState() - Expect(cs.NegotiatedProtocol).To(Equal(alpn)) + Expect(cs.TLS.NegotiatedProtocol).To(Equal(alpn)) close(done) }() @@ -381,7 +381,7 @@ var _ = Describe("Handshake tests", func() { Expect(err).ToNot(HaveOccurred()) defer sess.CloseWithError(0, "") cs := sess.ConnectionState() - Expect(cs.NegotiatedProtocol).To(Equal(alpn)) + Expect(cs.TLS.NegotiatedProtocol).To(Equal(alpn)) Eventually(done).Should(BeClosed()) Expect(ln.Close()).To(Succeed()) }) diff --git a/integrationtests/self/resumption_test.go b/integrationtests/self/resumption_test.go index 0ab8b434..fbccd4a1 100644 --- a/integrationtests/self/resumption_test.go +++ b/integrationtests/self/resumption_test.go @@ -65,11 +65,11 @@ var _ = Describe("TLS session resumption", func() { Expect(err).ToNot(HaveOccurred()) var sessionKey string Eventually(puts).Should(Receive(&sessionKey)) - Expect(sess.ConnectionState().DidResume).To(BeFalse()) + Expect(sess.ConnectionState().TLS.DidResume).To(BeFalse()) serverSess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - Expect(serverSess.ConnectionState().DidResume).To(BeFalse()) + Expect(serverSess.ConnectionState().TLS.DidResume).To(BeFalse()) sess, err = quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), @@ -78,11 +78,11 @@ var _ = Describe("TLS session resumption", func() { ) Expect(err).ToNot(HaveOccurred()) Expect(gets).To(Receive(Equal(sessionKey))) - Expect(sess.ConnectionState().DidResume).To(BeTrue()) + Expect(sess.ConnectionState().TLS.DidResume).To(BeTrue()) serverSess, err = server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - Expect(serverSess.ConnectionState().DidResume).To(BeTrue()) + Expect(serverSess.ConnectionState().TLS.DidResume).To(BeTrue()) }) It("doesn't use session resumption, if the config disables it", func() { @@ -104,11 +104,11 @@ var _ = Describe("TLS session resumption", func() { ) Expect(err).ToNot(HaveOccurred()) Consistently(puts).ShouldNot(Receive()) - Expect(sess.ConnectionState().DidResume).To(BeFalse()) + Expect(sess.ConnectionState().TLS.DidResume).To(BeFalse()) serverSess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - Expect(serverSess.ConnectionState().DidResume).To(BeFalse()) + Expect(serverSess.ConnectionState().TLS.DidResume).To(BeFalse()) sess, err = quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), @@ -116,10 +116,10 @@ var _ = Describe("TLS session resumption", func() { nil, ) Expect(err).ToNot(HaveOccurred()) - Expect(sess.ConnectionState().DidResume).To(BeFalse()) + Expect(sess.ConnectionState().TLS.DidResume).To(BeFalse()) serverSess, err = server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - Expect(serverSess.ConnectionState().DidResume).To(BeFalse()) + Expect(serverSess.ConnectionState().TLS.DidResume).To(BeFalse()) }) }) diff --git a/integrationtests/self/self_suite_test.go b/integrationtests/self/self_suite_test.go index 0b17be4c..ea4110ed 100644 --- a/integrationtests/self/self_suite_test.go +++ b/integrationtests/self/self_suite_test.go @@ -15,6 +15,7 @@ import ( "math/big" mrand "math/rand" "os" + "strconv" "sync" "testing" "time" @@ -317,6 +318,15 @@ func debugLog() bool { return len(logFileName) > 0 } +func scaleDuration(d time.Duration) time.Duration { + scaleFactor := 1 + if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set + scaleFactor = f + } + Expect(scaleFactor).ToNot(BeZero()) + return time.Duration(scaleFactor) * d +} + func TestSelf(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "Self integration tests") diff --git a/integrationtests/self/timeout_test.go b/integrationtests/self/timeout_test.go index 17129a3b..3baeefc1 100644 --- a/integrationtests/self/timeout_test.go +++ b/integrationtests/self/timeout_test.go @@ -8,9 +8,7 @@ import ( "io/ioutil" mrand "math/rand" "net" - "os" "runtime/pprof" - "strconv" "strings" "sync/atomic" "time" @@ -181,15 +179,6 @@ var _ = Describe("Timeout tests", func() { Context("timing out at the right time", func() { var idleTimeout time.Duration - scaleDuration := func(d time.Duration) time.Duration { - scaleFactor := 1 - if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set - scaleFactor = f - } - Expect(scaleFactor).ToNot(BeZero()) - return time.Duration(scaleFactor) * d - } - BeforeEach(func() { idleTimeout = scaleDuration(100 * time.Millisecond) }) diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index b4d8ca3a..3ab95efb 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -86,7 +86,7 @@ var _ = Describe("0-RTT", func() { data, err := ioutil.ReadAll(str) Expect(err).ToNot(HaveOccurred()) Expect(data).To(Equal(testdata)) - Expect(sess.ConnectionState().Used0RTT).To(Equal(expect0RTT)) + Expect(sess.ConnectionState().TLS.Used0RTT).To(Equal(expect0RTT)) close(done) }() @@ -101,7 +101,7 @@ var _ = Describe("0-RTT", func() { _, err = str.Write(testdata) Expect(err).ToNot(HaveOccurred()) Expect(str.Close()).To(Succeed()) - Expect(sess.ConnectionState().Used0RTT).To(Equal(expect0RTT)) + Expect(sess.ConnectionState().TLS.Used0RTT).To(Equal(expect0RTT)) Eventually(done).Should(BeClosed()) } diff --git a/interface.go b/interface.go index 99afd92c..eedb18fc 100644 --- a/interface.go +++ b/interface.go @@ -141,8 +141,6 @@ type StreamError interface { ErrorCode() ErrorCode } -type ConnectionState = handshake.ConnectionState - // A Session is a QUIC connection between two peers. type Session interface { // AcceptStream returns the next stream opened by the peer, blocking until one is available. @@ -189,6 +187,13 @@ type Session interface { // It blocks until the handshake completes. // Warning: This API should not be considered stable and might change soon. ConnectionState() ConnectionState + + // SendMessage sends a message as a datagram. + // See https://datatracker.ietf.org/doc/draft-pauly-quic-datagram/. + SendMessage([]byte) error + // ReceiveMessage gets a message received in a datagram. + // See https://datatracker.ietf.org/doc/draft-pauly-quic-datagram/. + ReceiveMessage() ([]byte, error) } // An EarlySession is a session that is handshaking. @@ -261,7 +266,16 @@ type Config struct { StatelessResetKey []byte // KeepAlive defines whether this peer will periodically send a packet to keep the connection alive. KeepAlive bool - Tracer logging.Tracer + // See https://datatracker.ietf.org/doc/draft-ietf-quic-datagram/. + // Datagrams will only be available when both peers enable datagram support. + EnableDatagrams bool + Tracer logging.Tracer +} + +// ConnectionState records basic details about a QUIC connection +type ConnectionState struct { + TLS handshake.ConnectionState + SupportsDatagrams bool } // A Listener for incoming QUIC connections diff --git a/internal/flowcontrol/stream_flow_controller_test.go b/internal/flowcontrol/stream_flow_controller_test.go index 8195c71e..9f1d65d2 100644 --- a/internal/flowcontrol/stream_flow_controller_test.go +++ b/internal/flowcontrol/stream_flow_controller_test.go @@ -30,9 +30,9 @@ var _ = Describe("Stream Flow controller", func() { Context("Constructor", func() { rttStats := &utils.RTTStats{} - receiveWindow := protocol.ByteCount(2000) - maxReceiveWindow := protocol.ByteCount(3000) - sendWindow := protocol.ByteCount(4000) + const receiveWindow protocol.ByteCount = 2000 + const maxReceiveWindow protocol.ByteCount = 3000 + const sendWindow protocol.ByteCount = 4000 It("sets the send and receive windows", func() { cc := NewConnectionFlowController(0, 0, nil, nil, utils.DefaultLogger) @@ -50,7 +50,7 @@ var _ = Describe("Stream Flow controller", func() { queued = true } - cc := NewConnectionFlowController(0, 0, nil, nil, utils.DefaultLogger) + cc := NewConnectionFlowController(receiveWindow, maxReceiveWindow, func() {}, nil, utils.DefaultLogger) fc := NewStreamFlowController(5, cc, receiveWindow, maxReceiveWindow, sendWindow, queueWindowUpdate, rttStats, utils.DefaultLogger).(*streamFlowController) fc.AddBytesRead(receiveWindow) Expect(queued).To(BeTrue()) diff --git a/internal/logutils/frame.go b/internal/logutils/frame.go index b1d3e672..6e0fd311 100644 --- a/internal/logutils/frame.go +++ b/internal/logutils/frame.go @@ -23,6 +23,10 @@ func ConvertFrame(frame wire.Frame) logging.Frame { Length: f.DataLen(), Fin: f.Fin, } + case *wire.DatagramFrame: + return &logging.DatagramFrame{ + Length: logging.ByteCount(len(f.Data)), + } default: return logging.Frame(frame) } diff --git a/internal/logutils/frame_test.go b/internal/logutils/frame_test.go index dd6b14e4..eb7a0f8e 100644 --- a/internal/logutils/frame_test.go +++ b/internal/logutils/frame_test.go @@ -34,6 +34,13 @@ var _ = Describe("CRYPTO frame", func() { Expect(sf.Fin).To(BeTrue()) }) + It("converts DATAGRAM frames", func() { + f := ConvertFrame(&wire.DatagramFrame{Data: []byte("foobar")}) + Expect(f).To(BeAssignableToTypeOf(&logging.DatagramFrame{})) + df := f.(*logging.DatagramFrame) + Expect(df.Length).To(Equal(logging.ByteCount(6))) + }) + It("converts other frames", func() { f := ConvertFrame(&wire.MaxDataFrame{MaximumData: 1234}) Expect(f).To(BeAssignableToTypeOf(&logging.MaxDataFrame{})) diff --git a/internal/mocks/quic/early_session.go b/internal/mocks/quic/early_session.go index 78d44ea1..0c81bead 100644 --- a/internal/mocks/quic/early_session.go +++ b/internal/mocks/quic/early_session.go @@ -197,6 +197,21 @@ func (mr *MockEarlySessionMockRecorder) OpenUniStreamSync(arg0 interface{}) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockEarlySession)(nil).OpenUniStreamSync), arg0) } +// ReceiveMessage mocks base method +func (m *MockEarlySession) ReceiveMessage() ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReceiveMessage") + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReceiveMessage indicates an expected call of ReceiveMessage +func (mr *MockEarlySessionMockRecorder) ReceiveMessage() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveMessage", reflect.TypeOf((*MockEarlySession)(nil).ReceiveMessage)) +} + // RemoteAddr mocks base method func (m *MockEarlySession) RemoteAddr() net.Addr { m.ctrl.T.Helper() @@ -210,3 +225,17 @@ func (mr *MockEarlySessionMockRecorder) RemoteAddr() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockEarlySession)(nil).RemoteAddr)) } + +// SendMessage mocks base method +func (m *MockEarlySession) SendMessage(arg0 []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendMessage", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendMessage indicates an expected call of SendMessage +func (mr *MockEarlySessionMockRecorder) SendMessage(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMessage", reflect.TypeOf((*MockEarlySession)(nil).SendMessage), arg0) +} diff --git a/internal/protocol/params.go b/internal/protocol/params.go index 12567fa7..f5b56eaa 100644 --- a/internal/protocol/params.go +++ b/internal/protocol/params.go @@ -132,6 +132,15 @@ const MaxPostHandshakeCryptoFrameSize = 1000 // but must ensure that a maximum size ACK frame fits into one packet. const MaxAckFrameSize ByteCount = 1000 +// MaxDatagramFrameSize is the maximum size of a DATAGRAM frame as defined in +// https://datatracker.ietf.org/doc/draft-pauly-quic-datagram/. +// The size is chosen such that a DATAGRAM frame fits into a QUIC packet. +const MaxDatagramFrameSize ByteCount = 1200 + +// DatagramRcvQueueLen is the length of the receive queue for DATAGRAM frames. +// See https://datatracker.ietf.org/doc/draft-pauly-quic-datagram/. +const DatagramRcvQueueLen = 128 + // MaxNumAckRanges is the maximum number of ACK ranges that we send in an ACK frame. // It also serves as a limit for the packet history. // If at any point we keep track of more ranges, old ranges are discarded. diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go index 22e4a6fb..74e9e98c 100644 --- a/internal/protocol/protocol.go +++ b/internal/protocol/protocol.go @@ -44,11 +44,14 @@ const ( ) // A ByteCount in QUIC -type ByteCount uint64 +type ByteCount int64 // MaxByteCount is the maximum value of a ByteCount const MaxByteCount = ByteCount(1<<62 - 1) +// InvalidByteCount is an invalid byte count +const InvalidByteCount ByteCount = -1 + // An ApplicationErrorCode is an application-defined error code. type ApplicationErrorCode uint64 diff --git a/internal/wire/datagram_frame.go b/internal/wire/datagram_frame.go new file mode 100644 index 00000000..bd5b67c0 --- /dev/null +++ b/internal/wire/datagram_frame.go @@ -0,0 +1,85 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +// A DatagramFrame is a DATAGRAM frame +type DatagramFrame struct { + DataLenPresent bool + Data []byte +} + +func parseDatagramFrame(r *bytes.Reader, _ protocol.VersionNumber) (*DatagramFrame, error) { + typeByte, err := r.ReadByte() + if err != nil { + return nil, err + } + + f := &DatagramFrame{} + f.DataLenPresent = typeByte&0x1 > 0 + + var length uint64 + if f.DataLenPresent { + var err error + len, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + if len > uint64(r.Len()) { + return nil, io.EOF + } + length = len + } else { + length = uint64(r.Len()) + } + f.Data = make([]byte, length) + if _, err := io.ReadFull(r, f.Data); err != nil { + return nil, err + } + return f, nil +} + +func (f *DatagramFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + typeByte := uint8(0x30) + if f.DataLenPresent { + typeByte ^= 0x1 + } + b.WriteByte(typeByte) + if f.DataLenPresent { + utils.WriteVarInt(b, uint64(len(f.Data))) + } + b.Write(f.Data) + return nil +} + +// MaxDataLen returns the maximum data length +func (f *DatagramFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.VersionNumber) protocol.ByteCount { + headerLen := protocol.ByteCount(1) + if f.DataLenPresent { + // pretend that the data size will be 1 bytes + // if it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterwards + headerLen++ + } + if headerLen > maxSize { + return 0 + } + maxDataLen := maxSize - headerLen + if f.DataLenPresent && utils.VarIntLen(uint64(maxDataLen)) != 1 { + maxDataLen-- + } + return maxDataLen +} + +// Length of a written frame +func (f *DatagramFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { + length := 1 + protocol.ByteCount(len(f.Data)) + if f.DataLenPresent { + length += utils.VarIntLen(uint64(len(f.Data))) + } + return length +} diff --git a/internal/wire/datagram_frame_test.go b/internal/wire/datagram_frame_test.go new file mode 100644 index 00000000..f4d9e3f2 --- /dev/null +++ b/internal/wire/datagram_frame_test.go @@ -0,0 +1,153 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("STREAM frame", func() { + Context("when parsing", func() { + It("parses a frame containing a length", func() { + data := []byte{0x30 ^ 0x1} + data = append(data, encodeVarInt(0x6)...) // length + data = append(data, []byte("foobar")...) + r := bytes.NewReader(data) + frame, err := parseDatagramFrame(r, versionIETFFrames) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.Data).To(Equal([]byte("foobar"))) + Expect(frame.DataLenPresent).To(BeTrue()) + Expect(r.Len()).To(BeZero()) + }) + + It("parses a frame without length", func() { + data := []byte{0x30} + data = append(data, []byte("Lorem ipsum dolor sit amet")...) + r := bytes.NewReader(data) + frame, err := parseDatagramFrame(r, versionIETFFrames) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.Data).To(Equal([]byte("Lorem ipsum dolor sit amet"))) + Expect(frame.DataLenPresent).To(BeFalse()) + Expect(r.Len()).To(BeZero()) + }) + + It("errors when the length is longer than the rest of the frame", func() { + data := []byte{0x30 ^ 0x1} + data = append(data, encodeVarInt(0x6)...) // length + data = append(data, []byte("fooba")...) + r := bytes.NewReader(data) + _, err := parseDatagramFrame(r, versionIETFFrames) + Expect(err).To(MatchError(io.EOF)) + }) + + It("errors on EOFs", func() { + data := []byte{0x30 ^ 0x1} + data = append(data, encodeVarInt(6)...) // length + data = append(data, []byte("foobar")...) + _, err := parseDatagramFrame(bytes.NewReader(data), versionIETFFrames) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parseDatagramFrame(bytes.NewReader(data[0:i]), versionIETFFrames) + Expect(err).To(MatchError(io.EOF)) + } + }) + }) + + Context("when writing", func() { + It("writes a frame with length", func() { + f := &DatagramFrame{ + DataLenPresent: true, + Data: []byte("foobar"), + } + buf := &bytes.Buffer{} + Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) + expected := []byte{0x30 ^ 0x1} + expected = append(expected, encodeVarInt(0x6)...) + expected = append(expected, []byte("foobar")...) + Expect(buf.Bytes()).To(Equal(expected)) + }) + + It("writes a frame without length", func() { + f := &DatagramFrame{Data: []byte("Lorem ipsum")} + buf := &bytes.Buffer{} + Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) + expected := []byte{0x30} + expected = append(expected, []byte("Lorem ipsum")...) + Expect(buf.Bytes()).To(Equal(expected)) + }) + }) + + Context("length", func() { + It("has the right length for a frame with length", func() { + f := &DatagramFrame{ + DataLenPresent: true, + Data: []byte("foobar"), + } + Expect(f.Length(versionIETFFrames)).To(Equal(1 + utils.VarIntLen(6) + 6)) + }) + + It("has the right length for a frame without length", func() { + f := &DatagramFrame{Data: []byte("foobar")} + Expect(f.Length(versionIETFFrames)).To(Equal(protocol.ByteCount(1 + 6))) + }) + }) + + Context("max data length", func() { + const maxSize = 3000 + + It("returns a data length such that the resulting frame has the right size, if data length is not present", func() { + data := make([]byte, maxSize) + f := &DatagramFrame{} + b := &bytes.Buffer{} + for i := 1; i < 3000; i++ { + b.Reset() + f.Data = nil + maxDataLen := f.MaxDataLen(protocol.ByteCount(i), versionIETFFrames) + if maxDataLen == 0 { // 0 means that no valid STREAM frame can be written + // check that writing a minimal size STREAM frame (i.e. with 1 byte data) is actually larger than the desired size + f.Data = []byte{0} + Expect(f.Write(b, versionIETFFrames)).To(Succeed()) + Expect(b.Len()).To(BeNumerically(">", i)) + continue + } + f.Data = data[:int(maxDataLen)] + Expect(f.Write(b, versionIETFFrames)).To(Succeed()) + Expect(b.Len()).To(Equal(i)) + } + }) + + It("always returns a data length such that the resulting frame has the right size, if data length is present", func() { + data := make([]byte, maxSize) + f := &DatagramFrame{DataLenPresent: true} + b := &bytes.Buffer{} + var frameOneByteTooSmallCounter int + for i := 1; i < 3000; i++ { + b.Reset() + f.Data = nil + maxDataLen := f.MaxDataLen(protocol.ByteCount(i), versionIETFFrames) + if maxDataLen == 0 { // 0 means that no valid STREAM frame can be written + // check that writing a minimal size STREAM frame (i.e. with 1 byte data) is actually larger than the desired size + f.Data = []byte{0} + Expect(f.Write(b, versionIETFFrames)).To(Succeed()) + Expect(b.Len()).To(BeNumerically(">", i)) + continue + } + f.Data = data[:int(maxDataLen)] + Expect(f.Write(b, versionIETFFrames)).To(Succeed()) + // There's *one* pathological case, where a data length of x can be encoded into 1 byte + // but a data lengths of x+1 needs 2 bytes + // In that case, it's impossible to create a STREAM frame of the desired size + if b.Len() == i-1 { + frameOneByteTooSmallCounter++ + continue + } + Expect(b.Len()).To(Equal(i)) + } + Expect(frameOneByteTooSmallCounter).To(Equal(1)) + }) + }) +}) diff --git a/internal/wire/frame_parser.go b/internal/wire/frame_parser.go index 9659883d..a858989e 100644 --- a/internal/wire/frame_parser.go +++ b/internal/wire/frame_parser.go @@ -13,12 +13,17 @@ import ( type frameParser struct { ackDelayExponent uint8 + supportsDatagrams bool + version protocol.VersionNumber } // NewFrameParser creates a new frame parser. -func NewFrameParser(v protocol.VersionNumber) FrameParser { - return &frameParser{version: v} +func NewFrameParser(supportsDatagrams bool, v protocol.VersionNumber) FrameParser { + return &frameParser{ + supportsDatagrams: supportsDatagrams, + version: v, + } } // ParseNextFrame parses the next frame @@ -87,6 +92,12 @@ func (p *frameParser) parseFrame(r *bytes.Reader, typeByte byte, encLevel protoc frame, err = parseConnectionCloseFrame(r, p.version) case 0x1e: frame, err = parseHandshakeDoneFrame(r, p.version) + case 0x30, 0x31: + if p.supportsDatagrams { + frame, err = parseDatagramFrame(r, p.version) + break + } + fallthrough default: err = errors.New("unknown frame type") } diff --git a/internal/wire/frame_parser_test.go b/internal/wire/frame_parser_test.go index bfbbfd7a..800af60f 100644 --- a/internal/wire/frame_parser_test.go +++ b/internal/wire/frame_parser_test.go @@ -18,7 +18,7 @@ var _ = Describe("Frame parsing", func() { BeforeEach(func() { buf = &bytes.Buffer{} - parser = NewFrameParser(versionIETFFrames) + parser = NewFrameParser(true, versionIETFFrames) }) It("returns nil if there's nothing more to read", func() { @@ -280,6 +280,24 @@ var _ = Describe("Frame parsing", func() { Expect(frame).To(Equal(f)) }) + It("unpacks DATAGRAM frames", func() { + f := &DatagramFrame{Data: []byte("foobar")} + buf := &bytes.Buffer{} + Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("errors when DATAGRAM frames are not supported", func() { + parser = NewFrameParser(false, versionIETFFrames) + f := &DatagramFrame{Data: []byte("foobar")} + buf := &bytes.Buffer{} + Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) + _, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).To(MatchError("FRAME_ENCODING_ERROR (frame type: 0x30): unknown frame type")) + }) + It("errors on invalid type", func() { _, err := parser.ParseNext(bytes.NewReader([]byte{0x42}), protocol.Encryption1RTT) Expect(err).To(MatchError("FRAME_ENCODING_ERROR (frame type: 0x42): unknown frame type")) @@ -318,6 +336,7 @@ var _ = Describe("Frame parsing", func() { &PathResponseFrame{}, &ConnectionCloseFrame{}, &HandshakeDoneFrame{}, + &DatagramFrame{}, } var framesSerialized [][]byte diff --git a/internal/wire/transport_parameter_test.go b/internal/wire/transport_parameter_test.go index 4c2c067d..45929717 100644 --- a/internal/wire/transport_parameter_test.go +++ b/internal/wire/transport_parameter_test.go @@ -46,11 +46,12 @@ var _ = Describe("Transport Parameters", func() { MaxAckDelay: 37 * time.Millisecond, StatelessResetToken: &protocol.StatelessResetToken{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x00}, ActiveConnectionIDLimit: 123, + MaxDatagramFrameSize: 876, } - Expect(p.String()).To(Equal("&wire.TransportParameters{OriginalDestinationConnectionID: 0xdeadbeef, InitialSourceConnectionID: 0xdecafbad, RetrySourceConnectionID: 0xdeadc0de, InitialMaxStreamDataBidiLocal: 1234, InitialMaxStreamDataBidiRemote: 2345, InitialMaxStreamDataUni: 3456, InitialMaxData: 4567, MaxBidiStreamNum: 1337, MaxUniStreamNum: 7331, MaxIdleTimeout: 42s, AckDelayExponent: 14, MaxAckDelay: 37ms, ActiveConnectionIDLimit: 123, StatelessResetToken: 0x112233445566778899aabbccddeeff00}")) + Expect(p.String()).To(Equal("&wire.TransportParameters{OriginalDestinationConnectionID: 0xdeadbeef, InitialSourceConnectionID: 0xdecafbad, RetrySourceConnectionID: 0xdeadc0de, InitialMaxStreamDataBidiLocal: 1234, InitialMaxStreamDataBidiRemote: 2345, InitialMaxStreamDataUni: 3456, InitialMaxData: 4567, MaxBidiStreamNum: 1337, MaxUniStreamNum: 7331, MaxIdleTimeout: 42s, AckDelayExponent: 14, MaxAckDelay: 37ms, ActiveConnectionIDLimit: 123, StatelessResetToken: 0x112233445566778899aabbccddeeff00, MaxDatagramFrameSize: 876}")) }) - It("has a string representation, if there's no stateless reset token and no Retry source connection id", func() { + It("has a string representation, if there's no stateless reset token, no Retry source connection id and no datagram support", func() { p := &TransportParameters{ InitialMaxStreamDataBidiLocal: 1234, InitialMaxStreamDataBidiRemote: 2345, @@ -64,6 +65,7 @@ var _ = Describe("Transport Parameters", func() { AckDelayExponent: 14, MaxAckDelay: 37 * time.Second, ActiveConnectionIDLimit: 89, + MaxDatagramFrameSize: protocol.InvalidByteCount, } Expect(p.String()).To(Equal("&wire.TransportParameters{OriginalDestinationConnectionID: 0xdeadbeef, InitialSourceConnectionID: (empty), InitialMaxStreamDataBidiLocal: 1234, InitialMaxStreamDataBidiRemote: 2345, InitialMaxStreamDataUni: 3456, InitialMaxData: 4567, MaxBidiStreamNum: 1337, MaxUniStreamNum: 7331, MaxIdleTimeout: 42s, AckDelayExponent: 14, MaxAckDelay: 37s, ActiveConnectionIDLimit: 89}")) }) @@ -87,6 +89,7 @@ var _ = Describe("Transport Parameters", func() { AckDelayExponent: 13, MaxAckDelay: 42 * time.Millisecond, ActiveConnectionIDLimit: getRandomValue(), + MaxDatagramFrameSize: protocol.ByteCount(getRandomValue()), } data := params.Marshal(protocol.PerspectiveServer) @@ -107,6 +110,7 @@ var _ = Describe("Transport Parameters", func() { Expect(p.AckDelayExponent).To(Equal(uint8(13))) Expect(p.MaxAckDelay).To(Equal(42 * time.Millisecond)) Expect(p.ActiveConnectionIDLimit).To(Equal(params.ActiveConnectionIDLimit)) + Expect(p.MaxDatagramFrameSize).To(Equal(params.MaxDatagramFrameSize)) }) It("doesn't marshal a retry_source_connection_id, if no Retry was performed", func() { diff --git a/internal/wire/transport_parameters.go b/internal/wire/transport_parameters.go index 1f81eb23..fbb10d6c 100644 --- a/internal/wire/transport_parameters.go +++ b/internal/wire/transport_parameters.go @@ -42,6 +42,8 @@ const ( activeConnectionIDLimitParameterID transportParameterID = 0xe initialSourceConnectionIDParameterID transportParameterID = 0xf retrySourceConnectionIDParameterID transportParameterID = 0x10 + // https://datatracker.ietf.org/doc/draft-ietf-quic-datagram/ + maxDatagramFrameSizeParameterID transportParameterID = 0x20 ) // PreferredAddress is the value encoding in the preferred_address transport parameter @@ -81,6 +83,8 @@ type TransportParameters struct { StatelessResetToken *protocol.StatelessResetToken ActiveConnectionIDLimit uint64 + + MaxDatagramFrameSize protocol.ByteCount } // Unmarshal the transport parameters @@ -96,12 +100,14 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec var parameterIDs []transportParameterID var ( - readAckDelayExponent bool - readMaxAckDelay bool readOriginalDestinationConnectionID bool readInitialSourceConnectionID bool ) + p.AckDelayExponent = protocol.DefaultAckDelayExponent + p.MaxAckDelay = protocol.DefaultMaxAckDelay + p.MaxDatagramFrameSize = protocol.InvalidByteCount + for r.Len() > 0 { paramIDInt, err := utils.ReadVarInt(r) if err != nil { @@ -118,12 +124,10 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec parameterIDs = append(parameterIDs, paramID) switch paramID { case ackDelayExponentParameterID: - readAckDelayExponent = true if err := p.readNumericTransportParameter(r, paramID, int(paramLen)); err != nil { return err } case maxAckDelayParameterID: - readMaxAckDelay = true if err := p.readNumericTransportParameter(r, paramID, int(paramLen)); err != nil { return err } @@ -135,7 +139,8 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec initialMaxStreamsUniParameterID, maxIdleTimeoutParameterID, maxUDPPayloadSizeParameterID, - activeConnectionIDLimitParameterID: + activeConnectionIDLimitParameterID, + maxDatagramFrameSizeParameterID: if err := p.readNumericTransportParameter(r, paramID, int(paramLen)); err != nil { return err } @@ -185,12 +190,6 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec if sentBy == protocol.PerspectiveServer && !readOriginalDestinationConnectionID { return errors.New("missing original_destination_connection_id") } - if !readAckDelayExponent { - p.AckDelayExponent = protocol.DefaultAckDelayExponent - } - if !readMaxAckDelay { - p.MaxAckDelay = protocol.DefaultMaxAckDelay - } if p.MaxUDPPayloadSize == 0 { p.MaxUDPPayloadSize = protocol.MaxByteCount } @@ -305,6 +304,8 @@ func (p *TransportParameters) readNumericTransportParameter( p.MaxAckDelay = maxAckDelay case activeConnectionIDLimitParameterID: p.ActiveConnectionIDLimit = val + case maxDatagramFrameSizeParameterID: + p.MaxDatagramFrameSize = protocol.ByteCount(val) default: return fmt.Errorf("TransportParameter BUG: transport parameter %d not found", paramID) } @@ -391,6 +392,9 @@ func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte { utils.WriteVarInt(b, uint64(p.RetrySourceConnectionID.Len())) b.Write(p.RetrySourceConnectionID.Bytes()) } + if p.MaxDatagramFrameSize != protocol.InvalidByteCount { + p.marshalVarintParam(b, maxDatagramFrameSizeParameterID, uint64(p.MaxDatagramFrameSize)) + } return b.Bytes() } @@ -463,6 +467,10 @@ func (p *TransportParameters) String() string { logString += ", StatelessResetToken: %#x" logParams = append(logParams, *p.StatelessResetToken) } + if p.MaxDatagramFrameSize != protocol.InvalidByteCount { + logString += ", MaxDatagramFrameSize: %d" + logParams = append(logParams, p.MaxDatagramFrameSize) + } logString += "}" return fmt.Sprintf(logString, logParams...) } diff --git a/logging/frame.go b/logging/frame.go index dfb875f0..75705092 100644 --- a/logging/frame.go +++ b/logging/frame.go @@ -59,3 +59,8 @@ type StreamFrame struct { Length ByteCount Fin bool } + +// A DatagramFrame is a DATAGRAM frame. +type DatagramFrame struct { + Length ByteCount +} diff --git a/mock_quic_session_test.go b/mock_quic_session_test.go index 656f77a2..d8b5ac0c 100644 --- a/mock_quic_session_test.go +++ b/mock_quic_session_test.go @@ -210,6 +210,21 @@ func (mr *MockQuicSessionMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomo return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockQuicSession)(nil).OpenUniStreamSync), arg0) } +// ReceiveMessage mocks base method +func (m *MockQuicSession) ReceiveMessage() ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReceiveMessage") + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReceiveMessage indicates an expected call of ReceiveMessage +func (mr *MockQuicSessionMockRecorder) ReceiveMessage() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveMessage", reflect.TypeOf((*MockQuicSession)(nil).ReceiveMessage)) +} + // RemoteAddr mocks base method func (m *MockQuicSession) RemoteAddr() net.Addr { m.ctrl.T.Helper() @@ -224,6 +239,20 @@ func (mr *MockQuicSessionMockRecorder) RemoteAddr() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockQuicSession)(nil).RemoteAddr)) } +// SendMessage mocks base method +func (m *MockQuicSession) SendMessage(arg0 []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendMessage", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendMessage indicates an expected call of SendMessage +func (mr *MockQuicSessionMockRecorder) SendMessage(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMessage", reflect.TypeOf((*MockQuicSession)(nil).SendMessage), arg0) +} + // destroy mocks base method func (m *MockQuicSession) destroy(arg0 error) { m.ctrl.T.Helper() diff --git a/packet_packer.go b/packet_packer.go index 858fcd71..6599a277 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -156,6 +156,7 @@ type packetPacker struct { pnManager packetNumberManager framer frameSource acks ackFrameSource + datagramQueue *datagramQueue retransmissionQueue *retransmissionQueue maxPacketSize protocol.ByteCount @@ -175,6 +176,7 @@ func newPacketPacker( cryptoSetup sealingManager, framer frameSource, acks ackFrameSource, + datagramQueue *datagramQueue, perspective protocol.Perspective, version protocol.VersionNumber, ) *packetPacker { @@ -185,6 +187,7 @@ func newPacketPacker( initialStream: initialStream, handshakeStream: handshakeStream, retransmissionQueue: retransmissionQueue, + datagramQueue: datagramQueue, perspective: perspective, version: version, framer: framer, @@ -576,10 +579,25 @@ func (p *packetPacker) maybeGetAppDataPacketWithEncLevel(maxPayloadSize protocol func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, ackAllowed bool) *payload { payload := &payload{} + + var hasDatagram bool + if p.datagramQueue != nil { + if datagram := p.datagramQueue.Get(); datagram != nil { + payload.frames = append(payload.frames, ackhandler.Frame{ + Frame: datagram, + // set it to a no-op. Then we won't set the default callback, which would retransmit the frame. + OnLost: func(wire.Frame) {}, + }) + payload.length += datagram.Length(p.version) + hasDatagram = true + } + } + var ack *wire.AckFrame hasData := p.framer.HasData() hasRetransmission := p.retransmissionQueue.HasAppData() - if ackAllowed { + // TODO: make sure ACKs are sent when a lot of DATAGRAMs are queued + if !hasDatagram && ackAllowed { ack = p.acks.GetAckFrame(protocol.Encryption1RTT, !hasRetransmission && !hasData) if ack != nil { payload.ack = ack diff --git a/packet_packer_test.go b/packet_packer_test.go index 49c9f23b..9fbe557b 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -7,17 +7,17 @@ import ( "net" "time" - "github.com/lucas-clemente/quic-go/internal/qerr" - "github.com/lucas-clemente/quic-go/internal/ackhandler" - - "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/mocks" mockackhandler "github.com/lucas-clemente/quic-go/internal/mocks/ackhandler" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" + "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/golang/mock/gomock" + . "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo/extensions/table" . "github.com/onsi/gomega" @@ -30,6 +30,7 @@ var _ = Describe("Packet packer", func() { var ( packer *packetPacker retransmissionQueue *retransmissionQueue + datagramQueue *datagramQueue framer *MockFrameSource ackFramer *MockAckFrameSource initialStream *MockCryptoStream @@ -90,6 +91,7 @@ var _ = Describe("Packet packer", func() { ackFramer = NewMockAckFrameSource(mockCtrl) sealingManager = NewMockSealingManager(mockCtrl) pnManager = mockackhandler.NewMockSentPacketHandler(mockCtrl) + datagramQueue = newDatagramQueue(func() {}, utils.DefaultLogger) packer = newPacketPacker( protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, @@ -102,6 +104,7 @@ var _ = Describe("Packet packer", func() { sealingManager, framer, ackFramer, + datagramQueue, protocol.PerspectiveServer, version, ) @@ -537,6 +540,33 @@ var _ = Describe("Packet packer", func() { Expect(p.buffer.Len()).ToNot(BeZero()) }) + It("packs DATAGRAM frames", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + f := &wire.DatagramFrame{ + DataLenPresent: true, + Data: []byte("foobar"), + } + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + datagramQueue.AddAndWait(f) + }() + // make sure the DATAGRAM has actually been queued + time.Sleep(scaleDuration(20 * time.Millisecond)) + + framer.EXPECT().HasData() + p, err := packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(HaveLen(1)) + Expect(p.frames[0].Frame).To(Equal(f)) + Expect(p.buffer.Data).ToNot(BeEmpty()) + Eventually(done).Should(BeClosed()) + }) + It("accounts for the space consumed by control frames", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) @@ -588,7 +618,7 @@ var _ = Describe("Packet packer", func() { Expect(err).ToNot(HaveOccurred()) Expect(secondPayloadByte).To(Equal(byte(0))) // ... followed by the PING - frameParser := wire.NewFrameParser(packer.version) + frameParser := wire.NewFrameParser(false, packer.version) frame, err := frameParser.ParseNext(r, protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&wire.PingFrame{})) @@ -625,7 +655,7 @@ var _ = Describe("Packet packer", func() { Expect(err).ToNot(HaveOccurred()) Expect(firstPayloadByte).To(Equal(byte(0))) // ... followed by the STREAM frame - frameParser := wire.NewFrameParser(packer.version) + frameParser := wire.NewFrameParser(true, packer.version) frame, err := frameParser.ParseNext(r, protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&wire.StreamFrame{})) @@ -1137,7 +1167,7 @@ var _ = Describe("Packet packer", func() { Expect(err).ToNot(HaveOccurred()) Expect(secondPayloadByte).To(Equal(byte(0))) // ... followed by the PING - frameParser := wire.NewFrameParser(packer.version) + frameParser := wire.NewFrameParser(false, packer.version) frame, err := frameParser.ParseNext(r, protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&wire.PingFrame{})) diff --git a/qlog/frame.go b/qlog/frame.go index 4cfb6c89..498d3982 100644 --- a/qlog/frame.go +++ b/qlog/frame.go @@ -57,6 +57,8 @@ func (f frame) MarshalJSONObject(enc *gojay.Encoder) { marshalConnectionCloseFrame(enc, frame) case *logging.HandshakeDoneFrame: marshalHandshakeDoneFrame(enc, frame) + case *logging.DatagramFrame: + marshalDatagramFrame(enc, frame) default: panic("unknown frame type") } @@ -218,3 +220,8 @@ func marshalConnectionCloseFrame(enc *gojay.Encoder, f *logging.ConnectionCloseF func marshalHandshakeDoneFrame(enc *gojay.Encoder, _ *logging.HandshakeDoneFrame) { enc.StringKey("frame_type", "handshake_done") } + +func marshalDatagramFrame(enc *gojay.Encoder, f *logging.DatagramFrame) { + enc.StringKey("frame_type", "datagram") + enc.Int64Key("length", int64(f.Length)) +} diff --git a/qlog/frame_test.go b/qlog/frame_test.go index b773e137..bab01f01 100644 --- a/qlog/frame_test.go +++ b/qlog/frame_test.go @@ -364,4 +364,14 @@ var _ = Describe("Frames", func() { }, ) }) + + It("marshals DATAGRAM frames", func() { + check( + &logging.DatagramFrame{Length: 1337}, + map[string]interface{}{ + "frame_type": "datagram", + "length": 1337, + }, + ) + }) }) diff --git a/server_test.go b/server_test.go index 210ad6fc..9bed2523 100644 --- a/server_test.go +++ b/server_test.go @@ -521,7 +521,7 @@ var _ = Describe("Server", func() { Expect(err).ToNot(HaveOccurred()) data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()]) Expect(err).ToNot(HaveOccurred()) - f, err := wire.NewFrameParser(hdr.Version).ParseNext(bytes.NewReader(data), protocol.EncryptionInitial) + f, err := wire.NewFrameParser(false, hdr.Version).ParseNext(bytes.NewReader(data), protocol.EncryptionInitial) Expect(err).ToNot(HaveOccurred()) Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) ccf := f.(*wire.ConnectionCloseFrame) diff --git a/session.go b/session.go index 68a6cecc..a1994b1a 100644 --- a/session.go +++ b/session.go @@ -204,6 +204,8 @@ type session struct { keepAlivePingSent bool keepAliveInterval time.Duration + datagramQueue *datagramQueue + logID string tracer logging.ConnectionTracer logger utils.Logger @@ -295,6 +297,9 @@ var newSession = func( InitialSourceConnectionID: srcConnID, RetrySourceConnectionID: retrySrcConnID, } + if s.config.EnableDatagrams { + params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize + } if s.tracer != nil { s.tracer.SentTransportParameters(params) } @@ -333,6 +338,7 @@ var newSession = func( cs, s.framer, s.receivedPacketHandler, + s.datagramQueue, s.perspective, s.version, ) @@ -414,6 +420,9 @@ var newClientSession = func( ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs, InitialSourceConnectionID: srcConnID, } + if s.config.EnableDatagrams { + params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize + } if s.tracer != nil { s.tracer.SentTransportParameters(params) } @@ -452,6 +461,7 @@ var newClientSession = func( cs, s.framer, s.receivedPacketHandler, + s.datagramQueue, s.perspective, s.version, ) @@ -471,7 +481,7 @@ var newClientSession = func( func (s *session) preSetup() { s.sendQueue = newSendQueue(s.conn) s.retransmissionQueue = newRetransmissionQueue(s.version) - s.frameParser = wire.NewFrameParser(s.version) + s.frameParser = wire.NewFrameParser(s.config.EnableDatagrams, s.version) s.rttStats = &utils.RTTStats{} s.connFlowController = flowcontrol.NewConnectionFlowController( protocol.InitialMaxData, @@ -501,6 +511,9 @@ func (s *session) preSetup() { s.sessionCreationTime = now s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.connFlowController, s.framer.QueueControlFrame) + if s.config.EnableDatagrams { + s.datagramQueue = newDatagramQueue(s.scheduleSending, s.logger) + } } // run the session main loop @@ -633,8 +646,15 @@ func (s *session) Context() context.Context { return s.ctx } +func (s *session) supportsDatagrams() bool { + return s.peerParams.MaxDatagramFrameSize != protocol.InvalidByteCount +} + func (s *session) ConnectionState() ConnectionState { - return s.cryptoStreamHandler.ConnectionState() + return ConnectionState{ + TLS: s.cryptoStreamHandler.ConnectionState(), + SupportsDatagrams: s.supportsDatagrams(), + } } // Time when the next keep-alive packet should be sent. @@ -1104,6 +1124,8 @@ func (s *session) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel, d err = s.handleRetireConnectionIDFrame(frame, destConnID) case *wire.HandshakeDoneFrame: err = s.handleHandshakeDoneFrame() + case *wire.DatagramFrame: + err = s.handleDatagramFrame(frame) default: err = fmt.Errorf("unexpected frame type: %s", reflect.ValueOf(&frame).Elem().Type().Name()) } @@ -1245,6 +1267,14 @@ func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.Encrypt return s.cryptoStreamHandler.SetLargest1RTTAcked(frame.LargestAcked()) } +func (s *session) handleDatagramFrame(f *wire.DatagramFrame) error { + if f.Length(s.version) > protocol.MaxDatagramFrameSize { + return qerr.NewError(qerr.ProtocolViolation, "DATAGRAM frame too large") + } + s.datagramQueue.HandleDatagramFrame(f) + return nil +} + // closeLocal closes the session and send a CONNECTION_CLOSE containing the error func (s *session) closeLocal(e error) { s.closeOnce.Do(func() { @@ -1307,6 +1337,9 @@ func (s *session) handleCloseError(closeErr closeError) { s.streamsMap.CloseWithError(quicErr) s.connIDManager.Close() + if s.datagramQueue != nil { + s.datagramQueue.CloseWithError(quicErr) + } if s.tracer != nil { // timeout errors are logged as soon as they occur (to distinguish between handshake and idle timeouts) @@ -1731,6 +1764,21 @@ func (s *session) onStreamCompleted(id protocol.StreamID) { } } +func (s *session) SendMessage(p []byte) error { + f := &wire.DatagramFrame{DataLenPresent: true} + if protocol.ByteCount(len(p)) > f.MaxDataLen(s.peerParams.MaxDatagramFrameSize, s.version) { + return errors.New("message too large") + } + f.Data = make([]byte, len(p)) + copy(f.Data, p) + s.datagramQueue.AddAndWait(f) + return nil +} + +func (s *session) ReceiveMessage() ([]byte, error) { + return s.datagramQueue.Receive() +} + func (s *session) LocalAddr() net.Addr { return s.conn.LocalAddr() } diff --git a/streams_map_incoming_generic_test.go b/streams_map_incoming_generic_test.go index 376e1811..d59045d0 100644 --- a/streams_map_incoming_generic_test.go +++ b/streams_map_incoming_generic_test.go @@ -39,7 +39,7 @@ var _ = Describe("Streams Map (incoming)", func() { checkFrameSerialization := func(f wire.Frame) { b := &bytes.Buffer{} ExpectWithOffset(1, f.Write(b, protocol.VersionTLS)).To(Succeed()) - frame, err := wire.NewFrameParser(protocol.VersionTLS).ParseNext(bytes.NewReader(b.Bytes()), protocol.Encryption1RTT) + frame, err := wire.NewFrameParser(false, protocol.VersionTLS).ParseNext(bytes.NewReader(b.Bytes()), protocol.Encryption1RTT) ExpectWithOffset(1, err).ToNot(HaveOccurred()) Expect(f).To(Equal(frame)) }