Merge pull request #2162 from lucas-clemente/datagram

implement the datagram draft
This commit is contained in:
Marten Seemann 2020-12-17 11:22:40 +07:00 committed by GitHub
commit 9693a46d31
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
36 changed files with 879 additions and 67 deletions

View file

@ -465,6 +465,7 @@ var _ = Describe("Client", func() {
ConnectionIDLength: 13,
StatelessResetKey: []byte("foobar"),
TokenStore: tokenStore,
EnableDatagrams: true,
}
c := populateClientConfig(config, false)
Expect(c.HandshakeIdleTimeout).To(Equal(1337 * time.Minute))
@ -474,6 +475,7 @@ var _ = Describe("Client", func() {
Expect(c.ConnectionIDLength).To(Equal(13))
Expect(c.StatelessResetKey).To(Equal([]byte("foobar")))
Expect(c.TokenStore).To(Equal(tokenStore))
Expect(c.EnableDatagrams).To(BeTrue())
})
It("errors when the Config contains an invalid version", func() {

View file

@ -105,6 +105,7 @@ func populateConfig(config *Config) *Config {
ConnectionIDLength: config.ConnectionIDLength,
StatelessResetKey: config.StatelessResetKey,
TokenStore: config.TokenStore,
EnableDatagrams: config.EnableDatagrams,
Tracer: config.Tracer,
}
}

View file

@ -69,6 +69,8 @@ var _ = Describe("Config", func() {
f.Set(reflect.ValueOf([]byte{1, 2, 3, 4}))
case "KeepAlive":
f.Set(reflect.ValueOf(true))
case "EnableDatagrams":
f.Set(reflect.ValueOf(true))
case "Tracer":
f.Set(reflect.ValueOf(mocklogging.NewMockTracer(mockCtrl)))
default:

77
datagram_queue.go Normal file
View file

@ -0,0 +1,77 @@
package quic
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type datagramQueue struct {
sendQueue chan *wire.DatagramFrame
rcvQueue chan []byte
closeErr error
closed chan struct{}
hasData func()
logger utils.Logger
}
func newDatagramQueue(hasData func(), logger utils.Logger) *datagramQueue {
return &datagramQueue{
hasData: hasData,
sendQueue: make(chan *wire.DatagramFrame),
rcvQueue: make(chan []byte, protocol.DatagramRcvQueueLen),
closed: make(chan struct{}),
logger: logger,
}
}
// AddAndWait queues a new DATAGRAM frame for sending.
// It blocks until the frame has been dequeued.
func (h *datagramQueue) AddAndWait(f *wire.DatagramFrame) error {
h.hasData()
select {
case h.sendQueue <- f:
return nil
case <-h.closed:
return h.closeErr
}
}
// Get dequeues a DATAGRAM frame for sending.
func (h *datagramQueue) Get() *wire.DatagramFrame {
select {
case f := <-h.sendQueue:
return f
default:
return nil
}
}
// HandleDatagramFrame handles a received DATAGRAM frame.
func (h *datagramQueue) HandleDatagramFrame(f *wire.DatagramFrame) {
data := make([]byte, len(f.Data))
copy(data, f.Data)
select {
case h.rcvQueue <- data:
default:
h.logger.Debugf("Discarding DATAGRAM frame (%d bytes payload)", len(f.Data))
}
}
// Receive gets a received DATAGRAM frame.
func (h *datagramQueue) Receive() ([]byte, error) {
select {
case data := <-h.rcvQueue:
return data, nil
case <-h.closed:
return nil, h.closeErr
}
}
func (h *datagramQueue) CloseWithError(e error) {
h.closeErr = e
close(h.closed)
}

98
datagram_queue_test.go Normal file
View file

@ -0,0 +1,98 @@
package quic
import (
"errors"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Datagram Queue", func() {
var queue *datagramQueue
var queued chan struct{}
BeforeEach(func() {
queued = make(chan struct{}, 100)
queue = newDatagramQueue(func() {
queued <- struct{}{}
}, utils.DefaultLogger)
})
Context("sending", func() {
It("returns nil when there's no datagram to send", func() {
Expect(queue.Get()).To(BeNil())
})
It("queues a datagram", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(done)
Expect(queue.AddAndWait(&wire.DatagramFrame{Data: []byte("foobar")})).To(Succeed())
}()
Eventually(queued).Should(HaveLen(1))
Consistently(done).ShouldNot(BeClosed())
f := queue.Get()
Expect(f).ToNot(BeNil())
Expect(f.Data).To(Equal([]byte("foobar")))
Eventually(done).Should(BeClosed())
Expect(queue.Get()).To(BeNil())
})
It("closes", func() {
errChan := make(chan error, 1)
go func() {
defer GinkgoRecover()
errChan <- queue.AddAndWait(&wire.DatagramFrame{Data: []byte("foobar")})
}()
Consistently(errChan).ShouldNot(Receive())
queue.CloseWithError(errors.New("test error"))
Eventually(errChan).Should(Receive(MatchError("test error")))
})
})
Context("receiving", func() {
It("receives DATAGRAM frames", func() {
queue.HandleDatagramFrame(&wire.DatagramFrame{Data: []byte("foo")})
queue.HandleDatagramFrame(&wire.DatagramFrame{Data: []byte("bar")})
data, err := queue.Receive()
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal([]byte("foo")))
data, err = queue.Receive()
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal([]byte("bar")))
})
It("blocks until a frame is received", func() {
c := make(chan []byte, 1)
go func() {
defer GinkgoRecover()
data, err := queue.Receive()
Expect(err).ToNot(HaveOccurred())
c <- data
}()
Consistently(c).ShouldNot(Receive())
queue.HandleDatagramFrame(&wire.DatagramFrame{Data: []byte("foobar")})
Eventually(c).Should(Receive(Equal([]byte("foobar"))))
})
It("closes", func() {
errChan := make(chan error, 1)
go func() {
defer GinkgoRecover()
_, err := queue.Receive()
errChan <- err
}()
Consistently(errChan).ShouldNot(Receive())
queue.CloseWithError(errors.New("test error"))
Eventually(errChan).Should(Receive(MatchError("test error")))
})
})
})

View file

@ -33,7 +33,7 @@ func Fuzz(data []byte) int {
encLevel := toEncLevel(data[0])
data = data[PrefixLen:]
parser := wire.NewFrameParser(version)
parser := wire.NewFrameParser(true, version)
parser.SetAckDelayExponent(protocol.DefaultAckDelayExponent)
r := bytes.NewReader(data)

View file

@ -250,7 +250,7 @@ func (c *client) doRequest(
return nil, newConnError(errorGeneralProtocolError, err)
}
connState := qtls.ToTLSConnectionState(c.session.ConnectionState())
connState := qtls.ToTLSConnectionState(c.session.ConnectionState().TLS)
res := &http.Response{
Proto: "HTTP/3",
ProtoMajor: 3,

View file

@ -16,7 +16,6 @@ import (
mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qtls"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/marten-seemann/qpack"
@ -240,7 +239,7 @@ var _ = Describe("Client", func() {
gomock.InOrder(
sess.EXPECT().HandshakeComplete().Return(handshakeCtx),
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
sess.EXPECT().ConnectionState().Return(qtls.ConnectionState{}),
sess.EXPECT().ConnectionState().Return(quic.ConnectionState{}),
)
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
str.EXPECT().Close()
@ -410,7 +409,7 @@ var _ = Describe("Client", func() {
req := request.WithContext(ctx)
sess.EXPECT().HandshakeComplete().Return(handshakeCtx)
sess.EXPECT().OpenStreamSync(ctx).Return(str, nil)
sess.EXPECT().ConnectionState().Return(qtls.ConnectionState{})
sess.EXPECT().ConnectionState().Return(quic.ConnectionState{})
buf := &bytes.Buffer{}
str.EXPECT().Close().MaxTimes(1)
@ -473,7 +472,7 @@ var _ = Describe("Client", func() {
It("decompresses the response", func() {
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
sess.EXPECT().ConnectionState().Return(qtls.ConnectionState{})
sess.EXPECT().ConnectionState().Return(quic.ConnectionState{})
buf := &bytes.Buffer{}
rw := newResponseWriter(buf, utils.DefaultLogger)
rw.Header().Set("Content-Encoding", "gzip")
@ -499,7 +498,7 @@ var _ = Describe("Client", func() {
It("only decompresses the response if the response contains the right content-encoding header", func() {
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
sess.EXPECT().ConnectionState().Return(qtls.ConnectionState{})
sess.EXPECT().ConnectionState().Return(quic.ConnectionState{})
buf := &bytes.Buffer{}
rw := newResponseWriter(buf, utils.DefaultLogger)
rw.Write([]byte("not gzipped"))

View file

@ -0,0 +1,141 @@
package self_test
import (
"context"
"encoding/binary"
"fmt"
mrand "math/rand"
"net"
"sync"
"sync/atomic"
"time"
"github.com/lucas-clemente/quic-go"
quicproxy "github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Datagram test", func() {
for _, v := range protocol.SupportedVersions {
version := v
Context(fmt.Sprintf("with QUIC version %s", version), func() {
const num = 100
var (
proxy *quicproxy.QuicProxy
serverConn, clientConn *net.UDPConn
dropped, total int32
)
startServerAndProxy := func() {
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
serverConn, err = net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
ln, err := quic.Listen(
serverConn,
getTLSConfig(),
getQuicConfig(&quic.Config{
EnableDatagrams: true,
Versions: []protocol.VersionNumber{version},
}),
)
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()
sess, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(sess.ConnectionState().SupportsDatagrams).To(BeTrue())
var wg sync.WaitGroup
wg.Add(num)
for i := 0; i < num; i++ {
go func(i int) {
defer GinkgoRecover()
defer wg.Done()
b := make([]byte, 8)
binary.BigEndian.PutUint64(b, uint64(i))
Expect(sess.SendMessage(b)).To(Succeed())
}(i)
}
wg.Wait()
}()
serverPort := ln.Addr().(*net.UDPAddr).Port
proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
// drop 10% of Short Header packets sent from the server
DropPacket: func(dir quicproxy.Direction, packet []byte) bool {
if dir == quicproxy.DirectionIncoming {
return false
}
// don't drop Long Header packets
if packet[0]&0x80 == 1 {
return false
}
drop := mrand.Int()%10 == 0
if drop {
atomic.AddInt32(&dropped, 1)
}
atomic.AddInt32(&total, 1)
return drop
},
})
Expect(err).ToNot(HaveOccurred())
}
BeforeEach(func() {
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
clientConn, err = net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
})
AfterEach(func() {
Expect(proxy.Close()).To(Succeed())
})
It("sends datagrams", func() {
startServerAndProxy()
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort()))
Expect(err).ToNot(HaveOccurred())
sess, err := quic.Dial(
clientConn,
raddr,
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(),
getQuicConfig(&quic.Config{
EnableDatagrams: true,
Versions: []protocol.VersionNumber{version},
}),
)
Expect(err).ToNot(HaveOccurred())
Expect(sess.ConnectionState().SupportsDatagrams).To(BeTrue())
var counter int
for {
// Close the session if no message is received for 100 ms.
timer := time.AfterFunc(scaleDuration(100*time.Millisecond), func() {
sess.CloseWithError(0, "")
})
if _, err := sess.ReceiveMessage(); err != nil {
break
}
timer.Stop()
counter++
}
numDropped := int(atomic.LoadInt32(&dropped))
expVal := num - numDropped
fmt.Fprintf(GinkgoWriter, "Dropped %d out of %d packets.\n", numDropped, atomic.LoadInt32(&total))
fmt.Fprintf(GinkgoWriter, "Received %d out of %d sent datagrams.\n", counter, num)
Expect(counter).To(And(
BeNumerically(">", expVal*9/10),
BeNumerically("<", num),
))
})
})
}
})

View file

@ -170,7 +170,7 @@ var _ = Describe("Handshake tests", func() {
data, err := ioutil.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal(PRData))
Expect(sess.ConnectionState().CipherSuite).To(Equal(suiteID))
Expect(sess.ConnectionState().TLS.CipherSuite).To(Equal(suiteID))
Expect(sess.CloseWithError(0, "")).To(Succeed())
})
}
@ -369,7 +369,7 @@ var _ = Describe("Handshake tests", func() {
sess, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
cs := sess.ConnectionState()
Expect(cs.NegotiatedProtocol).To(Equal(alpn))
Expect(cs.TLS.NegotiatedProtocol).To(Equal(alpn))
close(done)
}()
@ -381,7 +381,7 @@ var _ = Describe("Handshake tests", func() {
Expect(err).ToNot(HaveOccurred())
defer sess.CloseWithError(0, "")
cs := sess.ConnectionState()
Expect(cs.NegotiatedProtocol).To(Equal(alpn))
Expect(cs.TLS.NegotiatedProtocol).To(Equal(alpn))
Eventually(done).Should(BeClosed())
Expect(ln.Close()).To(Succeed())
})

View file

@ -65,11 +65,11 @@ var _ = Describe("TLS session resumption", func() {
Expect(err).ToNot(HaveOccurred())
var sessionKey string
Eventually(puts).Should(Receive(&sessionKey))
Expect(sess.ConnectionState().DidResume).To(BeFalse())
Expect(sess.ConnectionState().TLS.DidResume).To(BeFalse())
serverSess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(serverSess.ConnectionState().DidResume).To(BeFalse())
Expect(serverSess.ConnectionState().TLS.DidResume).To(BeFalse())
sess, err = quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
@ -78,11 +78,11 @@ var _ = Describe("TLS session resumption", func() {
)
Expect(err).ToNot(HaveOccurred())
Expect(gets).To(Receive(Equal(sessionKey)))
Expect(sess.ConnectionState().DidResume).To(BeTrue())
Expect(sess.ConnectionState().TLS.DidResume).To(BeTrue())
serverSess, err = server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(serverSess.ConnectionState().DidResume).To(BeTrue())
Expect(serverSess.ConnectionState().TLS.DidResume).To(BeTrue())
})
It("doesn't use session resumption, if the config disables it", func() {
@ -104,11 +104,11 @@ var _ = Describe("TLS session resumption", func() {
)
Expect(err).ToNot(HaveOccurred())
Consistently(puts).ShouldNot(Receive())
Expect(sess.ConnectionState().DidResume).To(BeFalse())
Expect(sess.ConnectionState().TLS.DidResume).To(BeFalse())
serverSess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(serverSess.ConnectionState().DidResume).To(BeFalse())
Expect(serverSess.ConnectionState().TLS.DidResume).To(BeFalse())
sess, err = quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
@ -116,10 +116,10 @@ var _ = Describe("TLS session resumption", func() {
nil,
)
Expect(err).ToNot(HaveOccurred())
Expect(sess.ConnectionState().DidResume).To(BeFalse())
Expect(sess.ConnectionState().TLS.DidResume).To(BeFalse())
serverSess, err = server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(serverSess.ConnectionState().DidResume).To(BeFalse())
Expect(serverSess.ConnectionState().TLS.DidResume).To(BeFalse())
})
})

View file

@ -15,6 +15,7 @@ import (
"math/big"
mrand "math/rand"
"os"
"strconv"
"sync"
"testing"
"time"
@ -317,6 +318,15 @@ func debugLog() bool {
return len(logFileName) > 0
}
func scaleDuration(d time.Duration) time.Duration {
scaleFactor := 1
if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set
scaleFactor = f
}
Expect(scaleFactor).ToNot(BeZero())
return time.Duration(scaleFactor) * d
}
func TestSelf(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Self integration tests")

View file

@ -8,9 +8,7 @@ import (
"io/ioutil"
mrand "math/rand"
"net"
"os"
"runtime/pprof"
"strconv"
"strings"
"sync/atomic"
"time"
@ -181,15 +179,6 @@ var _ = Describe("Timeout tests", func() {
Context("timing out at the right time", func() {
var idleTimeout time.Duration
scaleDuration := func(d time.Duration) time.Duration {
scaleFactor := 1
if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set
scaleFactor = f
}
Expect(scaleFactor).ToNot(BeZero())
return time.Duration(scaleFactor) * d
}
BeforeEach(func() {
idleTimeout = scaleDuration(100 * time.Millisecond)
})

View file

@ -86,7 +86,7 @@ var _ = Describe("0-RTT", func() {
data, err := ioutil.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal(testdata))
Expect(sess.ConnectionState().Used0RTT).To(Equal(expect0RTT))
Expect(sess.ConnectionState().TLS.Used0RTT).To(Equal(expect0RTT))
close(done)
}()
@ -101,7 +101,7 @@ var _ = Describe("0-RTT", func() {
_, err = str.Write(testdata)
Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed())
Expect(sess.ConnectionState().Used0RTT).To(Equal(expect0RTT))
Expect(sess.ConnectionState().TLS.Used0RTT).To(Equal(expect0RTT))
Eventually(done).Should(BeClosed())
}

View file

@ -141,8 +141,6 @@ type StreamError interface {
ErrorCode() ErrorCode
}
type ConnectionState = handshake.ConnectionState
// A Session is a QUIC connection between two peers.
type Session interface {
// AcceptStream returns the next stream opened by the peer, blocking until one is available.
@ -189,6 +187,13 @@ type Session interface {
// It blocks until the handshake completes.
// Warning: This API should not be considered stable and might change soon.
ConnectionState() ConnectionState
// SendMessage sends a message as a datagram.
// See https://datatracker.ietf.org/doc/draft-pauly-quic-datagram/.
SendMessage([]byte) error
// ReceiveMessage gets a message received in a datagram.
// See https://datatracker.ietf.org/doc/draft-pauly-quic-datagram/.
ReceiveMessage() ([]byte, error)
}
// An EarlySession is a session that is handshaking.
@ -261,7 +266,16 @@ type Config struct {
StatelessResetKey []byte
// KeepAlive defines whether this peer will periodically send a packet to keep the connection alive.
KeepAlive bool
Tracer logging.Tracer
// See https://datatracker.ietf.org/doc/draft-ietf-quic-datagram/.
// Datagrams will only be available when both peers enable datagram support.
EnableDatagrams bool
Tracer logging.Tracer
}
// ConnectionState records basic details about a QUIC connection
type ConnectionState struct {
TLS handshake.ConnectionState
SupportsDatagrams bool
}
// A Listener for incoming QUIC connections

View file

@ -30,9 +30,9 @@ var _ = Describe("Stream Flow controller", func() {
Context("Constructor", func() {
rttStats := &utils.RTTStats{}
receiveWindow := protocol.ByteCount(2000)
maxReceiveWindow := protocol.ByteCount(3000)
sendWindow := protocol.ByteCount(4000)
const receiveWindow protocol.ByteCount = 2000
const maxReceiveWindow protocol.ByteCount = 3000
const sendWindow protocol.ByteCount = 4000
It("sets the send and receive windows", func() {
cc := NewConnectionFlowController(0, 0, nil, nil, utils.DefaultLogger)
@ -50,7 +50,7 @@ var _ = Describe("Stream Flow controller", func() {
queued = true
}
cc := NewConnectionFlowController(0, 0, nil, nil, utils.DefaultLogger)
cc := NewConnectionFlowController(receiveWindow, maxReceiveWindow, func() {}, nil, utils.DefaultLogger)
fc := NewStreamFlowController(5, cc, receiveWindow, maxReceiveWindow, sendWindow, queueWindowUpdate, rttStats, utils.DefaultLogger).(*streamFlowController)
fc.AddBytesRead(receiveWindow)
Expect(queued).To(BeTrue())

View file

@ -23,6 +23,10 @@ func ConvertFrame(frame wire.Frame) logging.Frame {
Length: f.DataLen(),
Fin: f.Fin,
}
case *wire.DatagramFrame:
return &logging.DatagramFrame{
Length: logging.ByteCount(len(f.Data)),
}
default:
return logging.Frame(frame)
}

View file

@ -34,6 +34,13 @@ var _ = Describe("CRYPTO frame", func() {
Expect(sf.Fin).To(BeTrue())
})
It("converts DATAGRAM frames", func() {
f := ConvertFrame(&wire.DatagramFrame{Data: []byte("foobar")})
Expect(f).To(BeAssignableToTypeOf(&logging.DatagramFrame{}))
df := f.(*logging.DatagramFrame)
Expect(df.Length).To(Equal(logging.ByteCount(6)))
})
It("converts other frames", func() {
f := ConvertFrame(&wire.MaxDataFrame{MaximumData: 1234})
Expect(f).To(BeAssignableToTypeOf(&logging.MaxDataFrame{}))

View file

@ -197,6 +197,21 @@ func (mr *MockEarlySessionMockRecorder) OpenUniStreamSync(arg0 interface{}) *gom
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockEarlySession)(nil).OpenUniStreamSync), arg0)
}
// ReceiveMessage mocks base method
func (m *MockEarlySession) ReceiveMessage() ([]byte, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ReceiveMessage")
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ReceiveMessage indicates an expected call of ReceiveMessage
func (mr *MockEarlySessionMockRecorder) ReceiveMessage() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveMessage", reflect.TypeOf((*MockEarlySession)(nil).ReceiveMessage))
}
// RemoteAddr mocks base method
func (m *MockEarlySession) RemoteAddr() net.Addr {
m.ctrl.T.Helper()
@ -210,3 +225,17 @@ func (mr *MockEarlySessionMockRecorder) RemoteAddr() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockEarlySession)(nil).RemoteAddr))
}
// SendMessage mocks base method
func (m *MockEarlySession) SendMessage(arg0 []byte) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SendMessage", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// SendMessage indicates an expected call of SendMessage
func (mr *MockEarlySessionMockRecorder) SendMessage(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMessage", reflect.TypeOf((*MockEarlySession)(nil).SendMessage), arg0)
}

View file

@ -132,6 +132,15 @@ const MaxPostHandshakeCryptoFrameSize = 1000
// but must ensure that a maximum size ACK frame fits into one packet.
const MaxAckFrameSize ByteCount = 1000
// MaxDatagramFrameSize is the maximum size of a DATAGRAM frame as defined in
// https://datatracker.ietf.org/doc/draft-pauly-quic-datagram/.
// The size is chosen such that a DATAGRAM frame fits into a QUIC packet.
const MaxDatagramFrameSize ByteCount = 1200
// DatagramRcvQueueLen is the length of the receive queue for DATAGRAM frames.
// See https://datatracker.ietf.org/doc/draft-pauly-quic-datagram/.
const DatagramRcvQueueLen = 128
// MaxNumAckRanges is the maximum number of ACK ranges that we send in an ACK frame.
// It also serves as a limit for the packet history.
// If at any point we keep track of more ranges, old ranges are discarded.

View file

@ -44,11 +44,14 @@ const (
)
// A ByteCount in QUIC
type ByteCount uint64
type ByteCount int64
// MaxByteCount is the maximum value of a ByteCount
const MaxByteCount = ByteCount(1<<62 - 1)
// InvalidByteCount is an invalid byte count
const InvalidByteCount ByteCount = -1
// An ApplicationErrorCode is an application-defined error code.
type ApplicationErrorCode uint64

View file

@ -0,0 +1,85 @@
package wire
import (
"bytes"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// A DatagramFrame is a DATAGRAM frame
type DatagramFrame struct {
DataLenPresent bool
Data []byte
}
func parseDatagramFrame(r *bytes.Reader, _ protocol.VersionNumber) (*DatagramFrame, error) {
typeByte, err := r.ReadByte()
if err != nil {
return nil, err
}
f := &DatagramFrame{}
f.DataLenPresent = typeByte&0x1 > 0
var length uint64
if f.DataLenPresent {
var err error
len, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
}
if len > uint64(r.Len()) {
return nil, io.EOF
}
length = len
} else {
length = uint64(r.Len())
}
f.Data = make([]byte, length)
if _, err := io.ReadFull(r, f.Data); err != nil {
return nil, err
}
return f, nil
}
func (f *DatagramFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
typeByte := uint8(0x30)
if f.DataLenPresent {
typeByte ^= 0x1
}
b.WriteByte(typeByte)
if f.DataLenPresent {
utils.WriteVarInt(b, uint64(len(f.Data)))
}
b.Write(f.Data)
return nil
}
// MaxDataLen returns the maximum data length
func (f *DatagramFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.VersionNumber) protocol.ByteCount {
headerLen := protocol.ByteCount(1)
if f.DataLenPresent {
// pretend that the data size will be 1 bytes
// if it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterwards
headerLen++
}
if headerLen > maxSize {
return 0
}
maxDataLen := maxSize - headerLen
if f.DataLenPresent && utils.VarIntLen(uint64(maxDataLen)) != 1 {
maxDataLen--
}
return maxDataLen
}
// Length of a written frame
func (f *DatagramFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
length := 1 + protocol.ByteCount(len(f.Data))
if f.DataLenPresent {
length += utils.VarIntLen(uint64(len(f.Data)))
}
return length
}

View file

@ -0,0 +1,153 @@
package wire
import (
"bytes"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("STREAM frame", func() {
Context("when parsing", func() {
It("parses a frame containing a length", func() {
data := []byte{0x30 ^ 0x1}
data = append(data, encodeVarInt(0x6)...) // length
data = append(data, []byte("foobar")...)
r := bytes.NewReader(data)
frame, err := parseDatagramFrame(r, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
Expect(frame.Data).To(Equal([]byte("foobar")))
Expect(frame.DataLenPresent).To(BeTrue())
Expect(r.Len()).To(BeZero())
})
It("parses a frame without length", func() {
data := []byte{0x30}
data = append(data, []byte("Lorem ipsum dolor sit amet")...)
r := bytes.NewReader(data)
frame, err := parseDatagramFrame(r, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
Expect(frame.Data).To(Equal([]byte("Lorem ipsum dolor sit amet")))
Expect(frame.DataLenPresent).To(BeFalse())
Expect(r.Len()).To(BeZero())
})
It("errors when the length is longer than the rest of the frame", func() {
data := []byte{0x30 ^ 0x1}
data = append(data, encodeVarInt(0x6)...) // length
data = append(data, []byte("fooba")...)
r := bytes.NewReader(data)
_, err := parseDatagramFrame(r, versionIETFFrames)
Expect(err).To(MatchError(io.EOF))
})
It("errors on EOFs", func() {
data := []byte{0x30 ^ 0x1}
data = append(data, encodeVarInt(6)...) // length
data = append(data, []byte("foobar")...)
_, err := parseDatagramFrame(bytes.NewReader(data), versionIETFFrames)
Expect(err).NotTo(HaveOccurred())
for i := range data {
_, err := parseDatagramFrame(bytes.NewReader(data[0:i]), versionIETFFrames)
Expect(err).To(MatchError(io.EOF))
}
})
})
Context("when writing", func() {
It("writes a frame with length", func() {
f := &DatagramFrame{
DataLenPresent: true,
Data: []byte("foobar"),
}
buf := &bytes.Buffer{}
Expect(f.Write(buf, versionIETFFrames)).To(Succeed())
expected := []byte{0x30 ^ 0x1}
expected = append(expected, encodeVarInt(0x6)...)
expected = append(expected, []byte("foobar")...)
Expect(buf.Bytes()).To(Equal(expected))
})
It("writes a frame without length", func() {
f := &DatagramFrame{Data: []byte("Lorem ipsum")}
buf := &bytes.Buffer{}
Expect(f.Write(buf, versionIETFFrames)).To(Succeed())
expected := []byte{0x30}
expected = append(expected, []byte("Lorem ipsum")...)
Expect(buf.Bytes()).To(Equal(expected))
})
})
Context("length", func() {
It("has the right length for a frame with length", func() {
f := &DatagramFrame{
DataLenPresent: true,
Data: []byte("foobar"),
}
Expect(f.Length(versionIETFFrames)).To(Equal(1 + utils.VarIntLen(6) + 6))
})
It("has the right length for a frame without length", func() {
f := &DatagramFrame{Data: []byte("foobar")}
Expect(f.Length(versionIETFFrames)).To(Equal(protocol.ByteCount(1 + 6)))
})
})
Context("max data length", func() {
const maxSize = 3000
It("returns a data length such that the resulting frame has the right size, if data length is not present", func() {
data := make([]byte, maxSize)
f := &DatagramFrame{}
b := &bytes.Buffer{}
for i := 1; i < 3000; i++ {
b.Reset()
f.Data = nil
maxDataLen := f.MaxDataLen(protocol.ByteCount(i), versionIETFFrames)
if maxDataLen == 0 { // 0 means that no valid STREAM frame can be written
// check that writing a minimal size STREAM frame (i.e. with 1 byte data) is actually larger than the desired size
f.Data = []byte{0}
Expect(f.Write(b, versionIETFFrames)).To(Succeed())
Expect(b.Len()).To(BeNumerically(">", i))
continue
}
f.Data = data[:int(maxDataLen)]
Expect(f.Write(b, versionIETFFrames)).To(Succeed())
Expect(b.Len()).To(Equal(i))
}
})
It("always returns a data length such that the resulting frame has the right size, if data length is present", func() {
data := make([]byte, maxSize)
f := &DatagramFrame{DataLenPresent: true}
b := &bytes.Buffer{}
var frameOneByteTooSmallCounter int
for i := 1; i < 3000; i++ {
b.Reset()
f.Data = nil
maxDataLen := f.MaxDataLen(protocol.ByteCount(i), versionIETFFrames)
if maxDataLen == 0 { // 0 means that no valid STREAM frame can be written
// check that writing a minimal size STREAM frame (i.e. with 1 byte data) is actually larger than the desired size
f.Data = []byte{0}
Expect(f.Write(b, versionIETFFrames)).To(Succeed())
Expect(b.Len()).To(BeNumerically(">", i))
continue
}
f.Data = data[:int(maxDataLen)]
Expect(f.Write(b, versionIETFFrames)).To(Succeed())
// There's *one* pathological case, where a data length of x can be encoded into 1 byte
// but a data lengths of x+1 needs 2 bytes
// In that case, it's impossible to create a STREAM frame of the desired size
if b.Len() == i-1 {
frameOneByteTooSmallCounter++
continue
}
Expect(b.Len()).To(Equal(i))
}
Expect(frameOneByteTooSmallCounter).To(Equal(1))
})
})
})

View file

@ -13,12 +13,17 @@ import (
type frameParser struct {
ackDelayExponent uint8
supportsDatagrams bool
version protocol.VersionNumber
}
// NewFrameParser creates a new frame parser.
func NewFrameParser(v protocol.VersionNumber) FrameParser {
return &frameParser{version: v}
func NewFrameParser(supportsDatagrams bool, v protocol.VersionNumber) FrameParser {
return &frameParser{
supportsDatagrams: supportsDatagrams,
version: v,
}
}
// ParseNextFrame parses the next frame
@ -87,6 +92,12 @@ func (p *frameParser) parseFrame(r *bytes.Reader, typeByte byte, encLevel protoc
frame, err = parseConnectionCloseFrame(r, p.version)
case 0x1e:
frame, err = parseHandshakeDoneFrame(r, p.version)
case 0x30, 0x31:
if p.supportsDatagrams {
frame, err = parseDatagramFrame(r, p.version)
break
}
fallthrough
default:
err = errors.New("unknown frame type")
}

View file

@ -18,7 +18,7 @@ var _ = Describe("Frame parsing", func() {
BeforeEach(func() {
buf = &bytes.Buffer{}
parser = NewFrameParser(versionIETFFrames)
parser = NewFrameParser(true, versionIETFFrames)
})
It("returns nil if there's nothing more to read", func() {
@ -280,6 +280,24 @@ var _ = Describe("Frame parsing", func() {
Expect(frame).To(Equal(f))
})
It("unpacks DATAGRAM frames", func() {
f := &DatagramFrame{Data: []byte("foobar")}
buf := &bytes.Buffer{}
Expect(f.Write(buf, versionIETFFrames)).To(Succeed())
frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT)
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(Equal(f))
})
It("errors when DATAGRAM frames are not supported", func() {
parser = NewFrameParser(false, versionIETFFrames)
f := &DatagramFrame{Data: []byte("foobar")}
buf := &bytes.Buffer{}
Expect(f.Write(buf, versionIETFFrames)).To(Succeed())
_, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT)
Expect(err).To(MatchError("FRAME_ENCODING_ERROR (frame type: 0x30): unknown frame type"))
})
It("errors on invalid type", func() {
_, err := parser.ParseNext(bytes.NewReader([]byte{0x42}), protocol.Encryption1RTT)
Expect(err).To(MatchError("FRAME_ENCODING_ERROR (frame type: 0x42): unknown frame type"))
@ -318,6 +336,7 @@ var _ = Describe("Frame parsing", func() {
&PathResponseFrame{},
&ConnectionCloseFrame{},
&HandshakeDoneFrame{},
&DatagramFrame{},
}
var framesSerialized [][]byte

View file

@ -46,11 +46,12 @@ var _ = Describe("Transport Parameters", func() {
MaxAckDelay: 37 * time.Millisecond,
StatelessResetToken: &protocol.StatelessResetToken{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x00},
ActiveConnectionIDLimit: 123,
MaxDatagramFrameSize: 876,
}
Expect(p.String()).To(Equal("&wire.TransportParameters{OriginalDestinationConnectionID: 0xdeadbeef, InitialSourceConnectionID: 0xdecafbad, RetrySourceConnectionID: 0xdeadc0de, InitialMaxStreamDataBidiLocal: 1234, InitialMaxStreamDataBidiRemote: 2345, InitialMaxStreamDataUni: 3456, InitialMaxData: 4567, MaxBidiStreamNum: 1337, MaxUniStreamNum: 7331, MaxIdleTimeout: 42s, AckDelayExponent: 14, MaxAckDelay: 37ms, ActiveConnectionIDLimit: 123, StatelessResetToken: 0x112233445566778899aabbccddeeff00}"))
Expect(p.String()).To(Equal("&wire.TransportParameters{OriginalDestinationConnectionID: 0xdeadbeef, InitialSourceConnectionID: 0xdecafbad, RetrySourceConnectionID: 0xdeadc0de, InitialMaxStreamDataBidiLocal: 1234, InitialMaxStreamDataBidiRemote: 2345, InitialMaxStreamDataUni: 3456, InitialMaxData: 4567, MaxBidiStreamNum: 1337, MaxUniStreamNum: 7331, MaxIdleTimeout: 42s, AckDelayExponent: 14, MaxAckDelay: 37ms, ActiveConnectionIDLimit: 123, StatelessResetToken: 0x112233445566778899aabbccddeeff00, MaxDatagramFrameSize: 876}"))
})
It("has a string representation, if there's no stateless reset token and no Retry source connection id", func() {
It("has a string representation, if there's no stateless reset token, no Retry source connection id and no datagram support", func() {
p := &TransportParameters{
InitialMaxStreamDataBidiLocal: 1234,
InitialMaxStreamDataBidiRemote: 2345,
@ -64,6 +65,7 @@ var _ = Describe("Transport Parameters", func() {
AckDelayExponent: 14,
MaxAckDelay: 37 * time.Second,
ActiveConnectionIDLimit: 89,
MaxDatagramFrameSize: protocol.InvalidByteCount,
}
Expect(p.String()).To(Equal("&wire.TransportParameters{OriginalDestinationConnectionID: 0xdeadbeef, InitialSourceConnectionID: (empty), InitialMaxStreamDataBidiLocal: 1234, InitialMaxStreamDataBidiRemote: 2345, InitialMaxStreamDataUni: 3456, InitialMaxData: 4567, MaxBidiStreamNum: 1337, MaxUniStreamNum: 7331, MaxIdleTimeout: 42s, AckDelayExponent: 14, MaxAckDelay: 37s, ActiveConnectionIDLimit: 89}"))
})
@ -87,6 +89,7 @@ var _ = Describe("Transport Parameters", func() {
AckDelayExponent: 13,
MaxAckDelay: 42 * time.Millisecond,
ActiveConnectionIDLimit: getRandomValue(),
MaxDatagramFrameSize: protocol.ByteCount(getRandomValue()),
}
data := params.Marshal(protocol.PerspectiveServer)
@ -107,6 +110,7 @@ var _ = Describe("Transport Parameters", func() {
Expect(p.AckDelayExponent).To(Equal(uint8(13)))
Expect(p.MaxAckDelay).To(Equal(42 * time.Millisecond))
Expect(p.ActiveConnectionIDLimit).To(Equal(params.ActiveConnectionIDLimit))
Expect(p.MaxDatagramFrameSize).To(Equal(params.MaxDatagramFrameSize))
})
It("doesn't marshal a retry_source_connection_id, if no Retry was performed", func() {

View file

@ -42,6 +42,8 @@ const (
activeConnectionIDLimitParameterID transportParameterID = 0xe
initialSourceConnectionIDParameterID transportParameterID = 0xf
retrySourceConnectionIDParameterID transportParameterID = 0x10
// https://datatracker.ietf.org/doc/draft-ietf-quic-datagram/
maxDatagramFrameSizeParameterID transportParameterID = 0x20
)
// PreferredAddress is the value encoding in the preferred_address transport parameter
@ -81,6 +83,8 @@ type TransportParameters struct {
StatelessResetToken *protocol.StatelessResetToken
ActiveConnectionIDLimit uint64
MaxDatagramFrameSize protocol.ByteCount
}
// Unmarshal the transport parameters
@ -96,12 +100,14 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
var parameterIDs []transportParameterID
var (
readAckDelayExponent bool
readMaxAckDelay bool
readOriginalDestinationConnectionID bool
readInitialSourceConnectionID bool
)
p.AckDelayExponent = protocol.DefaultAckDelayExponent
p.MaxAckDelay = protocol.DefaultMaxAckDelay
p.MaxDatagramFrameSize = protocol.InvalidByteCount
for r.Len() > 0 {
paramIDInt, err := utils.ReadVarInt(r)
if err != nil {
@ -118,12 +124,10 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
parameterIDs = append(parameterIDs, paramID)
switch paramID {
case ackDelayExponentParameterID:
readAckDelayExponent = true
if err := p.readNumericTransportParameter(r, paramID, int(paramLen)); err != nil {
return err
}
case maxAckDelayParameterID:
readMaxAckDelay = true
if err := p.readNumericTransportParameter(r, paramID, int(paramLen)); err != nil {
return err
}
@ -135,7 +139,8 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
initialMaxStreamsUniParameterID,
maxIdleTimeoutParameterID,
maxUDPPayloadSizeParameterID,
activeConnectionIDLimitParameterID:
activeConnectionIDLimitParameterID,
maxDatagramFrameSizeParameterID:
if err := p.readNumericTransportParameter(r, paramID, int(paramLen)); err != nil {
return err
}
@ -185,12 +190,6 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
if sentBy == protocol.PerspectiveServer && !readOriginalDestinationConnectionID {
return errors.New("missing original_destination_connection_id")
}
if !readAckDelayExponent {
p.AckDelayExponent = protocol.DefaultAckDelayExponent
}
if !readMaxAckDelay {
p.MaxAckDelay = protocol.DefaultMaxAckDelay
}
if p.MaxUDPPayloadSize == 0 {
p.MaxUDPPayloadSize = protocol.MaxByteCount
}
@ -305,6 +304,8 @@ func (p *TransportParameters) readNumericTransportParameter(
p.MaxAckDelay = maxAckDelay
case activeConnectionIDLimitParameterID:
p.ActiveConnectionIDLimit = val
case maxDatagramFrameSizeParameterID:
p.MaxDatagramFrameSize = protocol.ByteCount(val)
default:
return fmt.Errorf("TransportParameter BUG: transport parameter %d not found", paramID)
}
@ -391,6 +392,9 @@ func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte {
utils.WriteVarInt(b, uint64(p.RetrySourceConnectionID.Len()))
b.Write(p.RetrySourceConnectionID.Bytes())
}
if p.MaxDatagramFrameSize != protocol.InvalidByteCount {
p.marshalVarintParam(b, maxDatagramFrameSizeParameterID, uint64(p.MaxDatagramFrameSize))
}
return b.Bytes()
}
@ -463,6 +467,10 @@ func (p *TransportParameters) String() string {
logString += ", StatelessResetToken: %#x"
logParams = append(logParams, *p.StatelessResetToken)
}
if p.MaxDatagramFrameSize != protocol.InvalidByteCount {
logString += ", MaxDatagramFrameSize: %d"
logParams = append(logParams, p.MaxDatagramFrameSize)
}
logString += "}"
return fmt.Sprintf(logString, logParams...)
}

View file

@ -59,3 +59,8 @@ type StreamFrame struct {
Length ByteCount
Fin bool
}
// A DatagramFrame is a DATAGRAM frame.
type DatagramFrame struct {
Length ByteCount
}

View file

@ -210,6 +210,21 @@ func (mr *MockQuicSessionMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomo
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockQuicSession)(nil).OpenUniStreamSync), arg0)
}
// ReceiveMessage mocks base method
func (m *MockQuicSession) ReceiveMessage() ([]byte, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ReceiveMessage")
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ReceiveMessage indicates an expected call of ReceiveMessage
func (mr *MockQuicSessionMockRecorder) ReceiveMessage() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveMessage", reflect.TypeOf((*MockQuicSession)(nil).ReceiveMessage))
}
// RemoteAddr mocks base method
func (m *MockQuicSession) RemoteAddr() net.Addr {
m.ctrl.T.Helper()
@ -224,6 +239,20 @@ func (mr *MockQuicSessionMockRecorder) RemoteAddr() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockQuicSession)(nil).RemoteAddr))
}
// SendMessage mocks base method
func (m *MockQuicSession) SendMessage(arg0 []byte) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SendMessage", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// SendMessage indicates an expected call of SendMessage
func (mr *MockQuicSessionMockRecorder) SendMessage(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMessage", reflect.TypeOf((*MockQuicSession)(nil).SendMessage), arg0)
}
// destroy mocks base method
func (m *MockQuicSession) destroy(arg0 error) {
m.ctrl.T.Helper()

View file

@ -156,6 +156,7 @@ type packetPacker struct {
pnManager packetNumberManager
framer frameSource
acks ackFrameSource
datagramQueue *datagramQueue
retransmissionQueue *retransmissionQueue
maxPacketSize protocol.ByteCount
@ -175,6 +176,7 @@ func newPacketPacker(
cryptoSetup sealingManager,
framer frameSource,
acks ackFrameSource,
datagramQueue *datagramQueue,
perspective protocol.Perspective,
version protocol.VersionNumber,
) *packetPacker {
@ -185,6 +187,7 @@ func newPacketPacker(
initialStream: initialStream,
handshakeStream: handshakeStream,
retransmissionQueue: retransmissionQueue,
datagramQueue: datagramQueue,
perspective: perspective,
version: version,
framer: framer,
@ -576,10 +579,25 @@ func (p *packetPacker) maybeGetAppDataPacketWithEncLevel(maxPayloadSize protocol
func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, ackAllowed bool) *payload {
payload := &payload{}
var hasDatagram bool
if p.datagramQueue != nil {
if datagram := p.datagramQueue.Get(); datagram != nil {
payload.frames = append(payload.frames, ackhandler.Frame{
Frame: datagram,
// set it to a no-op. Then we won't set the default callback, which would retransmit the frame.
OnLost: func(wire.Frame) {},
})
payload.length += datagram.Length(p.version)
hasDatagram = true
}
}
var ack *wire.AckFrame
hasData := p.framer.HasData()
hasRetransmission := p.retransmissionQueue.HasAppData()
if ackAllowed {
// TODO: make sure ACKs are sent when a lot of DATAGRAMs are queued
if !hasDatagram && ackAllowed {
ack = p.acks.GetAckFrame(protocol.Encryption1RTT, !hasRetransmission && !hasData)
if ack != nil {
payload.ack = ack

View file

@ -7,17 +7,17 @@ import (
"net"
"time"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/ackhandler"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/mocks"
mockackhandler "github.com/lucas-clemente/quic-go/internal/mocks/ackhandler"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/extensions/table"
. "github.com/onsi/gomega"
@ -30,6 +30,7 @@ var _ = Describe("Packet packer", func() {
var (
packer *packetPacker
retransmissionQueue *retransmissionQueue
datagramQueue *datagramQueue
framer *MockFrameSource
ackFramer *MockAckFrameSource
initialStream *MockCryptoStream
@ -90,6 +91,7 @@ var _ = Describe("Packet packer", func() {
ackFramer = NewMockAckFrameSource(mockCtrl)
sealingManager = NewMockSealingManager(mockCtrl)
pnManager = mockackhandler.NewMockSentPacketHandler(mockCtrl)
datagramQueue = newDatagramQueue(func() {}, utils.DefaultLogger)
packer = newPacketPacker(
protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
@ -102,6 +104,7 @@ var _ = Describe("Packet packer", func() {
sealingManager,
framer,
ackFramer,
datagramQueue,
protocol.PerspectiveServer,
version,
)
@ -537,6 +540,33 @@ var _ = Describe("Packet packer", func() {
Expect(p.buffer.Len()).ToNot(BeZero())
})
It("packs DATAGRAM frames", func() {
pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42))
sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil)
f := &wire.DatagramFrame{
DataLenPresent: true,
Data: []byte("foobar"),
}
done := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(done)
datagramQueue.AddAndWait(f)
}()
// make sure the DATAGRAM has actually been queued
time.Sleep(scaleDuration(20 * time.Millisecond))
framer.EXPECT().HasData()
p, err := packer.PackPacket()
Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0].Frame).To(Equal(f))
Expect(p.buffer.Data).ToNot(BeEmpty())
Eventually(done).Should(BeClosed())
})
It("accounts for the space consumed by control frames", func() {
pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil)
@ -588,7 +618,7 @@ var _ = Describe("Packet packer", func() {
Expect(err).ToNot(HaveOccurred())
Expect(secondPayloadByte).To(Equal(byte(0)))
// ... followed by the PING
frameParser := wire.NewFrameParser(packer.version)
frameParser := wire.NewFrameParser(false, packer.version)
frame, err := frameParser.ParseNext(r, protocol.Encryption1RTT)
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&wire.PingFrame{}))
@ -625,7 +655,7 @@ var _ = Describe("Packet packer", func() {
Expect(err).ToNot(HaveOccurred())
Expect(firstPayloadByte).To(Equal(byte(0)))
// ... followed by the STREAM frame
frameParser := wire.NewFrameParser(packer.version)
frameParser := wire.NewFrameParser(true, packer.version)
frame, err := frameParser.ParseNext(r, protocol.Encryption1RTT)
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&wire.StreamFrame{}))
@ -1137,7 +1167,7 @@ var _ = Describe("Packet packer", func() {
Expect(err).ToNot(HaveOccurred())
Expect(secondPayloadByte).To(Equal(byte(0)))
// ... followed by the PING
frameParser := wire.NewFrameParser(packer.version)
frameParser := wire.NewFrameParser(false, packer.version)
frame, err := frameParser.ParseNext(r, protocol.Encryption1RTT)
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&wire.PingFrame{}))

View file

@ -57,6 +57,8 @@ func (f frame) MarshalJSONObject(enc *gojay.Encoder) {
marshalConnectionCloseFrame(enc, frame)
case *logging.HandshakeDoneFrame:
marshalHandshakeDoneFrame(enc, frame)
case *logging.DatagramFrame:
marshalDatagramFrame(enc, frame)
default:
panic("unknown frame type")
}
@ -218,3 +220,8 @@ func marshalConnectionCloseFrame(enc *gojay.Encoder, f *logging.ConnectionCloseF
func marshalHandshakeDoneFrame(enc *gojay.Encoder, _ *logging.HandshakeDoneFrame) {
enc.StringKey("frame_type", "handshake_done")
}
func marshalDatagramFrame(enc *gojay.Encoder, f *logging.DatagramFrame) {
enc.StringKey("frame_type", "datagram")
enc.Int64Key("length", int64(f.Length))
}

View file

@ -364,4 +364,14 @@ var _ = Describe("Frames", func() {
},
)
})
It("marshals DATAGRAM frames", func() {
check(
&logging.DatagramFrame{Length: 1337},
map[string]interface{}{
"frame_type": "datagram",
"length": 1337,
},
)
})
})

View file

@ -521,7 +521,7 @@ var _ = Describe("Server", func() {
Expect(err).ToNot(HaveOccurred())
data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()])
Expect(err).ToNot(HaveOccurred())
f, err := wire.NewFrameParser(hdr.Version).ParseNext(bytes.NewReader(data), protocol.EncryptionInitial)
f, err := wire.NewFrameParser(false, hdr.Version).ParseNext(bytes.NewReader(data), protocol.EncryptionInitial)
Expect(err).ToNot(HaveOccurred())
Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
ccf := f.(*wire.ConnectionCloseFrame)

View file

@ -204,6 +204,8 @@ type session struct {
keepAlivePingSent bool
keepAliveInterval time.Duration
datagramQueue *datagramQueue
logID string
tracer logging.ConnectionTracer
logger utils.Logger
@ -295,6 +297,9 @@ var newSession = func(
InitialSourceConnectionID: srcConnID,
RetrySourceConnectionID: retrySrcConnID,
}
if s.config.EnableDatagrams {
params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize
}
if s.tracer != nil {
s.tracer.SentTransportParameters(params)
}
@ -333,6 +338,7 @@ var newSession = func(
cs,
s.framer,
s.receivedPacketHandler,
s.datagramQueue,
s.perspective,
s.version,
)
@ -414,6 +420,9 @@ var newClientSession = func(
ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs,
InitialSourceConnectionID: srcConnID,
}
if s.config.EnableDatagrams {
params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize
}
if s.tracer != nil {
s.tracer.SentTransportParameters(params)
}
@ -452,6 +461,7 @@ var newClientSession = func(
cs,
s.framer,
s.receivedPacketHandler,
s.datagramQueue,
s.perspective,
s.version,
)
@ -471,7 +481,7 @@ var newClientSession = func(
func (s *session) preSetup() {
s.sendQueue = newSendQueue(s.conn)
s.retransmissionQueue = newRetransmissionQueue(s.version)
s.frameParser = wire.NewFrameParser(s.version)
s.frameParser = wire.NewFrameParser(s.config.EnableDatagrams, s.version)
s.rttStats = &utils.RTTStats{}
s.connFlowController = flowcontrol.NewConnectionFlowController(
protocol.InitialMaxData,
@ -501,6 +511,9 @@ func (s *session) preSetup() {
s.sessionCreationTime = now
s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.connFlowController, s.framer.QueueControlFrame)
if s.config.EnableDatagrams {
s.datagramQueue = newDatagramQueue(s.scheduleSending, s.logger)
}
}
// run the session main loop
@ -633,8 +646,15 @@ func (s *session) Context() context.Context {
return s.ctx
}
func (s *session) supportsDatagrams() bool {
return s.peerParams.MaxDatagramFrameSize != protocol.InvalidByteCount
}
func (s *session) ConnectionState() ConnectionState {
return s.cryptoStreamHandler.ConnectionState()
return ConnectionState{
TLS: s.cryptoStreamHandler.ConnectionState(),
SupportsDatagrams: s.supportsDatagrams(),
}
}
// Time when the next keep-alive packet should be sent.
@ -1104,6 +1124,8 @@ func (s *session) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel, d
err = s.handleRetireConnectionIDFrame(frame, destConnID)
case *wire.HandshakeDoneFrame:
err = s.handleHandshakeDoneFrame()
case *wire.DatagramFrame:
err = s.handleDatagramFrame(frame)
default:
err = fmt.Errorf("unexpected frame type: %s", reflect.ValueOf(&frame).Elem().Type().Name())
}
@ -1245,6 +1267,14 @@ func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.Encrypt
return s.cryptoStreamHandler.SetLargest1RTTAcked(frame.LargestAcked())
}
func (s *session) handleDatagramFrame(f *wire.DatagramFrame) error {
if f.Length(s.version) > protocol.MaxDatagramFrameSize {
return qerr.NewError(qerr.ProtocolViolation, "DATAGRAM frame too large")
}
s.datagramQueue.HandleDatagramFrame(f)
return nil
}
// closeLocal closes the session and send a CONNECTION_CLOSE containing the error
func (s *session) closeLocal(e error) {
s.closeOnce.Do(func() {
@ -1307,6 +1337,9 @@ func (s *session) handleCloseError(closeErr closeError) {
s.streamsMap.CloseWithError(quicErr)
s.connIDManager.Close()
if s.datagramQueue != nil {
s.datagramQueue.CloseWithError(quicErr)
}
if s.tracer != nil {
// timeout errors are logged as soon as they occur (to distinguish between handshake and idle timeouts)
@ -1731,6 +1764,21 @@ func (s *session) onStreamCompleted(id protocol.StreamID) {
}
}
func (s *session) SendMessage(p []byte) error {
f := &wire.DatagramFrame{DataLenPresent: true}
if protocol.ByteCount(len(p)) > f.MaxDataLen(s.peerParams.MaxDatagramFrameSize, s.version) {
return errors.New("message too large")
}
f.Data = make([]byte, len(p))
copy(f.Data, p)
s.datagramQueue.AddAndWait(f)
return nil
}
func (s *session) ReceiveMessage() ([]byte, error) {
return s.datagramQueue.Receive()
}
func (s *session) LocalAddr() net.Addr {
return s.conn.LocalAddr()
}

View file

@ -39,7 +39,7 @@ var _ = Describe("Streams Map (incoming)", func() {
checkFrameSerialization := func(f wire.Frame) {
b := &bytes.Buffer{}
ExpectWithOffset(1, f.Write(b, protocol.VersionTLS)).To(Succeed())
frame, err := wire.NewFrameParser(protocol.VersionTLS).ParseNext(bytes.NewReader(b.Bytes()), protocol.Encryption1RTT)
frame, err := wire.NewFrameParser(false, protocol.VersionTLS).ParseNext(bytes.NewReader(b.Bytes()), protocol.Encryption1RTT)
ExpectWithOffset(1, err).ToNot(HaveOccurred())
Expect(f).To(Equal(frame))
}