add a context to Session.Open{Uni}StreamSync

This commit is contained in:
Marten Seemann 2019-06-07 16:19:56 +08:00
parent e63a991950
commit 2b8cece60a
20 changed files with 218 additions and 104 deletions

View file

@ -8,7 +8,7 @@
- Enforce application protocol negotiation (via `tls.Config.NextProtos`).
- Use a varint for error codes.
- Add support for [quic-trace](https://github.com/google/quic-trace).
- Add a context to `Listener.Accept` and `Session.Accept{Uni}Stream`.
- Add a context to `Listener.Accept`, `Session.Accept{Uni}Stream` and `Session.Open{Uni}StreamSync`.
## v0.11.0 (2019-04-05)

View file

@ -59,7 +59,7 @@ func clientMain() error {
return err
}
stream, err := session.OpenStreamSync()
stream, err := session.OpenStreamSync(context.Background())
if err != nil {
return err
}

View file

@ -2,6 +2,7 @@ package http3
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
@ -100,7 +101,7 @@ func (c *client) dial() error {
func (c *client) setupSession() error {
// open the control stream
str, err := c.session.OpenUniStreamSync()
str, err := c.session.OpenUniStream()
if err != nil {
return err
}
@ -138,7 +139,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
return nil, c.handshakeErr
}
str, err := c.session.OpenStreamSync()
str, err := c.session.OpenStreamSync(context.Background())
if err != nil {
return nil, err
}

View file

@ -3,6 +3,7 @@ package http3
import (
"bytes"
"compress/gzip"
"context"
"crypto/tls"
"errors"
"io"
@ -126,8 +127,8 @@ var _ = Describe("Client", func() {
testErr := errors.New("stream open error")
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
session := mockquic.NewMockSession(mockCtrl)
session.EXPECT().OpenUniStreamSync().Return(nil, testErr).MaxTimes(1)
session.EXPECT().OpenStreamSync().Return(nil, testErr).MaxTimes(1)
session.EXPECT().OpenUniStream().Return(nil, testErr).MaxTimes(1)
session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).MaxTimes(1)
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return session, nil
@ -169,7 +170,7 @@ var _ = Describe("Client", func() {
controlStr.EXPECT().Write(gomock.Any()).MaxTimes(1) // SETTINGS frame
str = mockquic.NewMockStream(mockCtrl)
sess = mockquic.NewMockSession(mockCtrl)
sess.EXPECT().OpenUniStreamSync().Return(controlStr, nil).MaxTimes(1)
sess.EXPECT().OpenUniStream().Return(controlStr, nil).MaxTimes(1)
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return sess, nil
}
@ -179,7 +180,7 @@ var _ = Describe("Client", func() {
})
It("sends a request", func() {
sess.EXPECT().OpenStreamSync().Return(str, nil)
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
buf := &bytes.Buffer{}
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return buf.Write(p)
@ -200,7 +201,7 @@ var _ = Describe("Client", func() {
rw := newResponseWriter(rspBuf, utils.DefaultLogger)
rw.WriteHeader(418)
sess.EXPECT().OpenStreamSync().Return(str, nil)
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
str.EXPECT().Write(gomock.Any()).AnyTimes()
str.EXPECT().Close()
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
@ -234,7 +235,7 @@ var _ = Describe("Client", func() {
BeforeEach(func() {
strBuf = &bytes.Buffer{}
sess.EXPECT().OpenStreamSync().Return(str, nil)
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
body := &mockBody{}
body.SetData([]byte("request body"))
var err error
@ -295,7 +296,7 @@ var _ = Describe("Client", func() {
})
It("adds the gzip header to requests", func() {
sess.EXPECT().OpenStreamSync().Return(str, nil)
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
buf := &bytes.Buffer{}
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return buf.Write(p)
@ -310,7 +311,7 @@ var _ = Describe("Client", func() {
It("doesn't add gzip if the header disable it", func() {
client = newClient("quic.clemente.io:1337", nil, &roundTripperOpts{DisableCompression: true}, nil, nil)
sess.EXPECT().OpenStreamSync().Return(str, nil)
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
buf := &bytes.Buffer{}
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return buf.Write(p)
@ -324,7 +325,7 @@ var _ = Describe("Client", func() {
})
It("decompresses the response", func() {
sess.EXPECT().OpenStreamSync().Return(str, nil)
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
buf := &bytes.Buffer{}
rw := newResponseWriter(buf, utils.DefaultLogger)
rw.Header().Set("Content-Encoding", "gzip")
@ -348,7 +349,7 @@ var _ = Describe("Client", func() {
})
It("only decompresses the response if the response contains the right content-encoding header", func() {
sess.EXPECT().OpenStreamSync().Return(str, nil)
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
buf := &bytes.Buffer{}
rw := newResponseWriter(buf, utils.DefaultLogger)
rw.Write([]byte("not gzipped"))

View file

@ -2,6 +2,7 @@ package http3
import (
"bytes"
"context"
"crypto/tls"
"errors"
"io"
@ -91,8 +92,8 @@ var _ = Describe("RoundTripper", func() {
testErr := errors.New("test err")
req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
session.EXPECT().OpenUniStreamSync().AnyTimes().Return(nil, testErr)
session.EXPECT().OpenStreamSync().Return(nil, testErr)
session.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr)
session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr)
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ErrorCode, string) { close(closed) })
_, err = rt.RoundTrip(req)
Expect(err).To(MatchError(testErr))
@ -128,8 +129,8 @@ var _ = Describe("RoundTripper", func() {
It("reuses existing clients", func() {
closed := make(chan struct{})
testErr := errors.New("test err")
session.EXPECT().OpenUniStreamSync().AnyTimes().Return(nil, testErr)
session.EXPECT().OpenStreamSync().Return(nil, testErr).Times(2)
session.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr)
session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).Times(2)
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ErrorCode, string) { close(closed) })
req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)
Expect(err).ToNot(HaveOccurred())

View file

@ -128,7 +128,7 @@ func (s *Server) handleConn(sess quic.Session) {
decoder := qpack.NewDecoder(nil)
// send a SETTINGS frame
str, err := sess.OpenUniStreamSync()
str, err := sess.OpenUniStream()
if err != nil {
s.logger.Debugf("Opening the control stream failed.")
return

View file

@ -41,7 +41,7 @@ var _ = Describe("Stream Cancelations", func() {
go func() {
defer GinkgoRecover()
defer wg.Done()
str, err := sess.OpenUniStreamSync()
str, err := sess.OpenUniStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
if _, err = str.Write(testserver.PRData); err != nil {
Expect(err).To(MatchError(fmt.Sprintf("Stream %d was reset with error code %d", str.StreamID(), str.StreamID())))
@ -203,7 +203,7 @@ var _ = Describe("Stream Cancelations", func() {
for i := 0; i < numStreams; i++ {
go func() {
defer GinkgoRecover()
str, err := sess.OpenUniStreamSync()
str, err := sess.OpenUniStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
// cancel about 2/3 of the streams
if rand.Int31()%3 != 0 {
@ -234,7 +234,7 @@ var _ = Describe("Stream Cancelations", func() {
for i := 0; i < numStreams; i++ {
go func() {
defer GinkgoRecover()
str, err := sess.OpenUniStreamSync()
str, err := sess.OpenUniStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
// only write some data from about 1/3 of the streams, then cancel
if rand.Int31()%3 != 0 {
@ -273,7 +273,7 @@ var _ = Describe("Stream Cancelations", func() {
go func() {
defer GinkgoRecover()
defer wg.Done()
str, err := sess.OpenUniStreamSync()
str, err := sess.OpenUniStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
// cancel about half of the streams
if rand.Int31()%2 == 0 {
@ -347,7 +347,7 @@ var _ = Describe("Stream Cancelations", func() {
go func() {
defer GinkgoRecover()
defer wg.Done()
str, err := sess.OpenUniStreamSync()
str, err := sess.OpenUniStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
// cancel about half of the streams
length := len(testserver.PRData)

View file

@ -47,7 +47,7 @@ var _ = Describe("Bidirectional streams", func() {
var wg sync.WaitGroup
wg.Add(numStreams)
for i := 0; i < numStreams; i++ {
str, err := sess.OpenStreamSync()
str, err := sess.OpenStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
data := testserver.GeneratePRData(25 * i)
go func() {

View file

@ -42,7 +42,7 @@ var _ = Describe("Unidirectional Streams", func() {
runSendingPeer := func(sess quic.Session) {
for i := 0; i < numStreams; i++ {
str, err := sess.OpenUniStreamSync()
str, err := sess.OpenUniStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()

View file

@ -143,7 +143,7 @@ type Session interface {
// It blocks until a new stream can be opened.
// If the error is non-nil, it satisfies the net.Error interface.
// If the session was closed due to a timeout, Timeout() will be true.
OpenStreamSync() (Stream, error)
OpenStreamSync(context.Context) (Stream, error)
// OpenUniStream opens a new outgoing unidirectional QUIC stream.
// If the error is non-nil, it satisfies the net.Error interface.
// When reaching the peer's stream limit, Temporary() will be true.
@ -153,7 +153,7 @@ type Session interface {
// It blocks until a new stream can be opened.
// If the error is non-nil, it satisfies the net.Error interface.
// If the session was closed due to a timeout, Timeout() will be true.
OpenUniStreamSync() (SendStream, error)
OpenUniStreamSync(context.Context) (SendStream, error)
// LocalAddr returns the local address.
LocalAddr() net.Addr
// RemoteAddr returns the address of the peer.

View file

@ -154,18 +154,18 @@ func (mr *MockSessionMockRecorder) OpenStream() *gomock.Call {
}
// OpenStreamSync mocks base method
func (m *MockSession) OpenStreamSync() (quic_go.Stream, error) {
func (m *MockSession) OpenStreamSync(arg0 context.Context) (quic_go.Stream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OpenStreamSync")
ret := m.ctrl.Call(m, "OpenStreamSync", arg0)
ret0, _ := ret[0].(quic_go.Stream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenStreamSync indicates an expected call of OpenStreamSync
func (mr *MockSessionMockRecorder) OpenStreamSync() *gomock.Call {
func (mr *MockSessionMockRecorder) OpenStreamSync(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockSession)(nil).OpenStreamSync))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockSession)(nil).OpenStreamSync), arg0)
}
// OpenUniStream mocks base method
@ -184,18 +184,18 @@ func (mr *MockSessionMockRecorder) OpenUniStream() *gomock.Call {
}
// OpenUniStreamSync mocks base method
func (m *MockSession) OpenUniStreamSync() (quic_go.SendStream, error) {
func (m *MockSession) OpenUniStreamSync(arg0 context.Context) (quic_go.SendStream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OpenUniStreamSync")
ret := m.ctrl.Call(m, "OpenUniStreamSync", arg0)
ret0, _ := ret[0].(quic_go.SendStream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenUniStreamSync indicates an expected call of OpenUniStreamSync
func (mr *MockSessionMockRecorder) OpenUniStreamSync() *gomock.Call {
func (mr *MockSessionMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockSession)(nil).OpenUniStreamSync))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockSession)(nil).OpenUniStreamSync), arg0)
}
// RemoteAddr mocks base method

View file

@ -167,18 +167,18 @@ func (mr *MockQuicSessionMockRecorder) OpenStream() *gomock.Call {
}
// OpenStreamSync mocks base method
func (m *MockQuicSession) OpenStreamSync() (Stream, error) {
func (m *MockQuicSession) OpenStreamSync(arg0 context.Context) (Stream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OpenStreamSync")
ret := m.ctrl.Call(m, "OpenStreamSync", arg0)
ret0, _ := ret[0].(Stream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenStreamSync indicates an expected call of OpenStreamSync
func (mr *MockQuicSessionMockRecorder) OpenStreamSync() *gomock.Call {
func (mr *MockQuicSessionMockRecorder) OpenStreamSync(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockQuicSession)(nil).OpenStreamSync))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockQuicSession)(nil).OpenStreamSync), arg0)
}
// OpenUniStream mocks base method
@ -197,18 +197,18 @@ func (mr *MockQuicSessionMockRecorder) OpenUniStream() *gomock.Call {
}
// OpenUniStreamSync mocks base method
func (m *MockQuicSession) OpenUniStreamSync() (SendStream, error) {
func (m *MockQuicSession) OpenUniStreamSync(arg0 context.Context) (SendStream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OpenUniStreamSync")
ret := m.ctrl.Call(m, "OpenUniStreamSync", arg0)
ret0, _ := ret[0].(SendStream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenUniStreamSync indicates an expected call of OpenUniStreamSync
func (mr *MockQuicSessionMockRecorder) OpenUniStreamSync() *gomock.Call {
func (mr *MockQuicSessionMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockQuicSession)(nil).OpenUniStreamSync))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockQuicSession)(nil).OpenUniStreamSync), arg0)
}
// RemoteAddr mocks base method

View file

@ -153,18 +153,18 @@ func (mr *MockStreamManagerMockRecorder) OpenStream() *gomock.Call {
}
// OpenStreamSync mocks base method
func (m *MockStreamManager) OpenStreamSync() (Stream, error) {
func (m *MockStreamManager) OpenStreamSync(arg0 context.Context) (Stream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OpenStreamSync")
ret := m.ctrl.Call(m, "OpenStreamSync", arg0)
ret0, _ := ret[0].(Stream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenStreamSync indicates an expected call of OpenStreamSync
func (mr *MockStreamManagerMockRecorder) OpenStreamSync() *gomock.Call {
func (mr *MockStreamManagerMockRecorder) OpenStreamSync(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenStreamSync))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenStreamSync), arg0)
}
// OpenUniStream mocks base method
@ -183,18 +183,18 @@ func (mr *MockStreamManagerMockRecorder) OpenUniStream() *gomock.Call {
}
// OpenUniStreamSync mocks base method
func (m *MockStreamManager) OpenUniStreamSync() (SendStream, error) {
func (m *MockStreamManager) OpenUniStreamSync(arg0 context.Context) (SendStream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OpenUniStreamSync")
ret := m.ctrl.Call(m, "OpenUniStreamSync", arg0)
ret0, _ := ret[0].(SendStream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenUniStreamSync indicates an expected call of OpenUniStreamSync
func (mr *MockStreamManagerMockRecorder) OpenUniStreamSync() *gomock.Call {
func (mr *MockStreamManagerMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenUniStreamSync))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenUniStreamSync), arg0)
}
// UpdateLimits mocks base method

View file

@ -37,8 +37,8 @@ type streamManager interface {
GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error)
OpenStream() (Stream, error)
OpenUniStream() (SendStream, error)
OpenStreamSync() (Stream, error)
OpenUniStreamSync() (SendStream, error)
OpenStreamSync(context.Context) (Stream, error)
OpenUniStreamSync(context.Context) (SendStream, error)
AcceptStream(context.Context) (Stream, error)
AcceptUniStream(context.Context) (ReceiveStream, error)
DeleteStream(protocol.StreamID) error
@ -1246,16 +1246,16 @@ func (s *session) OpenStream() (Stream, error) {
return s.streamsMap.OpenStream()
}
func (s *session) OpenStreamSync() (Stream, error) {
return s.streamsMap.OpenStreamSync()
func (s *session) OpenStreamSync(ctx context.Context) (Stream, error) {
return s.streamsMap.OpenStreamSync(ctx)
}
func (s *session) OpenUniStream() (SendStream, error) {
return s.streamsMap.OpenUniStream()
}
func (s *session) OpenUniStreamSync() (SendStream, error) {
return s.streamsMap.OpenUniStreamSync()
func (s *session) OpenUniStreamSync(ctx context.Context) (SendStream, error) {
return s.streamsMap.OpenUniStreamSync(ctx)
}
func (s *session) newFlowController(id protocol.StreamID) flowcontrol.StreamFlowController {

View file

@ -1423,8 +1423,8 @@ var _ = Describe("Session", func() {
It("opens streams synchronously", func() {
mstr := NewMockStreamI(mockCtrl)
streamManager.EXPECT().OpenStreamSync().Return(mstr, nil)
str, err := sess.OpenStreamSync()
streamManager.EXPECT().OpenStreamSync(context.Background()).Return(mstr, nil)
str, err := sess.OpenStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(str).To(Equal(mstr))
})
@ -1439,8 +1439,8 @@ var _ = Describe("Session", func() {
It("opens unidirectional streams synchronously", func() {
mstr := NewMockSendStreamI(mockCtrl)
streamManager.EXPECT().OpenUniStreamSync().Return(mstr, nil)
str, err := sess.OpenUniStreamSync()
streamManager.EXPECT().OpenUniStreamSync(context.Background()).Return(mstr, nil)
str, err := sess.OpenUniStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(str).To(Equal(mstr))
})

View file

@ -109,8 +109,8 @@ func (m *streamsMap) OpenStream() (Stream, error) {
return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
}
func (m *streamsMap) OpenStreamSync() (Stream, error) {
str, err := m.outgoingBidiStreams.OpenStreamSync()
func (m *streamsMap) OpenStreamSync(ctx context.Context) (Stream, error) {
str, err := m.outgoingBidiStreams.OpenStreamSync(ctx)
return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
}
@ -119,8 +119,8 @@ func (m *streamsMap) OpenUniStream() (SendStream, error) {
return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
}
func (m *streamsMap) OpenUniStreamSync() (SendStream, error) {
str, err := m.outgoingUniStreams.OpenStreamSync()
func (m *streamsMap) OpenUniStreamSync(ctx context.Context) (SendStream, error) {
str, err := m.outgoingUniStreams.OpenStreamSync(ctx)
return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
}

View file

@ -5,6 +5,7 @@
package quic
import (
"context"
"sync"
"github.com/lucas-clemente/quic-go/internal/protocol"
@ -14,10 +15,12 @@ import (
type outgoingBidiStreamsMap struct {
mutex sync.RWMutex
openQueue []chan struct{}
streams map[protocol.StreamNum]streamI
openQueue map[uint64]chan struct{}
lowestInQueue uint64
highestInQueue uint64
nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync)
maxStream protocol.StreamNum // the maximum stream ID we're allowed to open
blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream
@ -34,6 +37,7 @@ func newOutgoingBidiStreamsMap(
) *outgoingBidiStreamsMap {
return &outgoingBidiStreamsMap{
streams: make(map[protocol.StreamNum]streamI),
openQueue: make(map[uint64]chan struct{}),
maxStream: protocol.InvalidStreamNum,
nextStream: 1,
newStream: newStream,
@ -57,7 +61,7 @@ func (m *outgoingBidiStreamsMap) OpenStream() (streamI, error) {
return m.openStream(), nil
}
func (m *outgoingBidiStreamsMap) OpenStreamSync() (streamI, error) {
func (m *outgoingBidiStreamsMap) OpenStreamSync(ctx context.Context) (streamI, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
@ -65,17 +69,32 @@ func (m *outgoingBidiStreamsMap) OpenStreamSync() (streamI, error) {
return nil, m.closeErr
}
if err := ctx.Err(); err != nil {
return nil, err
}
if len(m.openQueue) == 0 && m.nextStream <= m.maxStream {
return m.openStream(), nil
}
waitChan := make(chan struct{}, 1)
m.openQueue = append(m.openQueue, waitChan)
queuePos := m.highestInQueue
m.highestInQueue++
if len(m.openQueue) == 0 {
m.lowestInQueue = queuePos
}
m.openQueue[queuePos] = waitChan
m.maybeSendBlockedFrame()
for {
m.mutex.Unlock()
<-waitChan
select {
case <-ctx.Done():
m.mutex.Lock()
delete(m.openQueue, queuePos)
return nil, ctx.Err()
case <-waitChan:
}
m.mutex.Lock()
if m.closeErr != nil {
@ -86,7 +105,7 @@ func (m *outgoingBidiStreamsMap) OpenStreamSync() (streamI, error) {
continue
}
str := m.openStream()
m.openQueue = m.openQueue[1:]
delete(m.openQueue, queuePos)
m.unblockOpenSync()
return str, nil
}
@ -159,9 +178,15 @@ func (m *outgoingBidiStreamsMap) unblockOpenSync() {
if len(m.openQueue) == 0 {
return
}
select {
case m.openQueue[0] <- struct{}{}:
default:
for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ {
c, ok := m.openQueue[qp]
if !ok { // entry was deleted because the context was canceled
continue
}
close(c)
m.openQueue[qp] = nil
m.lowestInQueue = qp + 1
return
}
}
@ -172,7 +197,9 @@ func (m *outgoingBidiStreamsMap) CloseWithError(err error) {
str.closeForShutdown(err)
}
for _, c := range m.openQueue {
close(c)
if c != nil {
close(c)
}
}
m.mutex.Unlock()
}

View file

@ -1,6 +1,7 @@
package quic
import (
"context"
"sync"
"github.com/lucas-clemente/quic-go/internal/protocol"
@ -12,10 +13,12 @@ import (
type outgoingItemsMap struct {
mutex sync.RWMutex
openQueue []chan struct{}
streams map[protocol.StreamNum]item
openQueue map[uint64]chan struct{}
lowestInQueue uint64
highestInQueue uint64
nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync)
maxStream protocol.StreamNum // the maximum stream ID we're allowed to open
blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream
@ -32,6 +35,7 @@ func newOutgoingItemsMap(
) *outgoingItemsMap {
return &outgoingItemsMap{
streams: make(map[protocol.StreamNum]item),
openQueue: make(map[uint64]chan struct{}),
maxStream: protocol.InvalidStreamNum,
nextStream: 1,
newStream: newStream,
@ -55,7 +59,7 @@ func (m *outgoingItemsMap) OpenStream() (item, error) {
return m.openStream(), nil
}
func (m *outgoingItemsMap) OpenStreamSync() (item, error) {
func (m *outgoingItemsMap) OpenStreamSync(ctx context.Context) (item, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
@ -63,17 +67,32 @@ func (m *outgoingItemsMap) OpenStreamSync() (item, error) {
return nil, m.closeErr
}
if err := ctx.Err(); err != nil {
return nil, err
}
if len(m.openQueue) == 0 && m.nextStream <= m.maxStream {
return m.openStream(), nil
}
waitChan := make(chan struct{}, 1)
m.openQueue = append(m.openQueue, waitChan)
queuePos := m.highestInQueue
m.highestInQueue++
if len(m.openQueue) == 0 {
m.lowestInQueue = queuePos
}
m.openQueue[queuePos] = waitChan
m.maybeSendBlockedFrame()
for {
m.mutex.Unlock()
<-waitChan
select {
case <-ctx.Done():
m.mutex.Lock()
delete(m.openQueue, queuePos)
return nil, ctx.Err()
case <-waitChan:
}
m.mutex.Lock()
if m.closeErr != nil {
@ -84,7 +103,7 @@ func (m *outgoingItemsMap) OpenStreamSync() (item, error) {
continue
}
str := m.openStream()
m.openQueue = m.openQueue[1:]
delete(m.openQueue, queuePos)
m.unblockOpenSync()
return str, nil
}
@ -157,9 +176,15 @@ func (m *outgoingItemsMap) unblockOpenSync() {
if len(m.openQueue) == 0 {
return
}
select {
case m.openQueue[0] <- struct{}{}:
default:
for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ {
c, ok := m.openQueue[qp]
if !ok { // entry was deleted because the context was canceled
continue
}
close(c)
m.openQueue[qp] = nil
m.lowestInQueue = qp + 1
return
}
}
@ -170,7 +195,9 @@ func (m *outgoingItemsMap) CloseWithError(err error) {
str.closeForShutdown(err)
}
for _, c := range m.openQueue {
close(c)
if c != nil {
close(c)
}
}
m.mutex.Unlock()
}

View file

@ -1,6 +1,7 @@
package quic
import (
"context"
"errors"
"github.com/golang/mock/gomock"
@ -106,12 +107,19 @@ var _ = Describe("Streams Map (outgoing)", func() {
expectTooManyStreamsError(err)
})
It("returns immediately when called with a canceled context", func() {
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := m.OpenStreamSync(ctx)
Expect(err).To(MatchError("context canceled"))
})
It("blocks until a stream can be opened synchronously", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
str, err := m.OpenStreamSync()
str, err := m.OpenStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
close(done)
@ -122,12 +130,34 @@ var _ = Describe("Streams Map (outgoing)", func() {
Eventually(done).Should(BeClosed())
})
It("unblocks when the context is canceled", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := m.OpenStreamSync(ctx)
Expect(err).To(MatchError("context canceled"))
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
cancel()
Eventually(done).Should(BeClosed())
// make sure that the next stream openend is stream 1
m.SetMaxStream(1000)
str, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred())
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
})
It("opens streams in the right order", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes()
done1 := make(chan struct{})
go func() {
defer GinkgoRecover()
str, err := m.OpenStreamSync()
str, err := m.OpenStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
close(done1)
@ -136,7 +166,7 @@ var _ = Describe("Streams Map (outgoing)", func() {
done2 := make(chan struct{})
go func() {
defer GinkgoRecover()
str, err := m.OpenStreamSync()
str, err := m.OpenStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2)))
close(done2)
@ -155,20 +185,20 @@ var _ = Describe("Streams Map (outgoing)", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := m.OpenStreamSync()
_, err := m.OpenStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
done <- struct{}{}
}()
go func() {
defer GinkgoRecover()
_, err := m.OpenStreamSync()
_, err := m.OpenStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
done <- struct{}{}
}()
Consistently(done).ShouldNot(Receive())
go func() {
defer GinkgoRecover()
_, err := m.OpenStreamSync()
_, err := m.OpenStreamSync(context.Background())
Expect(err).To(MatchError("test done"))
done <- struct{}{}
}()
@ -188,7 +218,7 @@ var _ = Describe("Streams Map (outgoing)", func() {
openedSync := make(chan struct{})
go func() {
defer GinkgoRecover()
str, err := m.OpenStreamSync()
str, err := m.OpenStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
close(openedSync)
@ -229,7 +259,7 @@ var _ = Describe("Streams Map (outgoing)", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := m.OpenStreamSync()
_, err := m.OpenStreamSync(context.Background())
Expect(err).To(MatchError(testErr))
close(done)
}()

View file

@ -5,6 +5,7 @@
package quic
import (
"context"
"sync"
"github.com/lucas-clemente/quic-go/internal/protocol"
@ -14,10 +15,12 @@ import (
type outgoingUniStreamsMap struct {
mutex sync.RWMutex
openQueue []chan struct{}
streams map[protocol.StreamNum]sendStreamI
openQueue map[uint64]chan struct{}
lowestInQueue uint64
highestInQueue uint64
nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync)
maxStream protocol.StreamNum // the maximum stream ID we're allowed to open
blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream
@ -34,6 +37,7 @@ func newOutgoingUniStreamsMap(
) *outgoingUniStreamsMap {
return &outgoingUniStreamsMap{
streams: make(map[protocol.StreamNum]sendStreamI),
openQueue: make(map[uint64]chan struct{}),
maxStream: protocol.InvalidStreamNum,
nextStream: 1,
newStream: newStream,
@ -57,7 +61,7 @@ func (m *outgoingUniStreamsMap) OpenStream() (sendStreamI, error) {
return m.openStream(), nil
}
func (m *outgoingUniStreamsMap) OpenStreamSync() (sendStreamI, error) {
func (m *outgoingUniStreamsMap) OpenStreamSync(ctx context.Context) (sendStreamI, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
@ -65,17 +69,32 @@ func (m *outgoingUniStreamsMap) OpenStreamSync() (sendStreamI, error) {
return nil, m.closeErr
}
if err := ctx.Err(); err != nil {
return nil, err
}
if len(m.openQueue) == 0 && m.nextStream <= m.maxStream {
return m.openStream(), nil
}
waitChan := make(chan struct{}, 1)
m.openQueue = append(m.openQueue, waitChan)
queuePos := m.highestInQueue
m.highestInQueue++
if len(m.openQueue) == 0 {
m.lowestInQueue = queuePos
}
m.openQueue[queuePos] = waitChan
m.maybeSendBlockedFrame()
for {
m.mutex.Unlock()
<-waitChan
select {
case <-ctx.Done():
m.mutex.Lock()
delete(m.openQueue, queuePos)
return nil, ctx.Err()
case <-waitChan:
}
m.mutex.Lock()
if m.closeErr != nil {
@ -86,7 +105,7 @@ func (m *outgoingUniStreamsMap) OpenStreamSync() (sendStreamI, error) {
continue
}
str := m.openStream()
m.openQueue = m.openQueue[1:]
delete(m.openQueue, queuePos)
m.unblockOpenSync()
return str, nil
}
@ -159,9 +178,15 @@ func (m *outgoingUniStreamsMap) unblockOpenSync() {
if len(m.openQueue) == 0 {
return
}
select {
case m.openQueue[0] <- struct{}{}:
default:
for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ {
c, ok := m.openQueue[qp]
if !ok { // entry was deleted because the context was canceled
continue
}
close(c)
m.openQueue[qp] = nil
m.lowestInQueue = qp + 1
return
}
}
@ -172,7 +197,9 @@ func (m *outgoingUniStreamsMap) CloseWithError(err error) {
str.closeForShutdown(err)
}
for _, c := range m.openQueue {
close(c)
if c != nil {
close(c)
}
}
m.mutex.Unlock()
}