add a context to Session.Accept{Uni}Stream

This commit is contained in:
Marten Seemann 2019-05-28 17:32:18 +01:00
parent f74082b2fb
commit 5550ba2c3b
28 changed files with 140 additions and 105 deletions

View file

@ -8,7 +8,7 @@
- Enforce application protocol negotiation (via `tls.Config.NextProtos`). - Enforce application protocol negotiation (via `tls.Config.NextProtos`).
- Use a varint for error codes. - Use a varint for error codes.
- Add support for [quic-trace](https://github.com/google/quic-trace). - Add support for [quic-trace](https://github.com/google/quic-trace).
- Add a context to `Listener.Accept`. - Add a context to `Listener.Accept` and `Session.Accept{Uni}Stream`.
## v0.11.0 (2019-04-05) ## v0.11.0 (2019-04-05)

View file

@ -67,7 +67,7 @@ func init() {
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
close(handshakeChan) close(handshakeChan)
str, err := sess.AcceptStream() str, err := sess.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
buf := &bytes.Buffer{} buf := &bytes.Buffer{}

View file

@ -40,7 +40,7 @@ func echoServer() error {
if err != nil { if err != nil {
return err return err
} }
stream, err := sess.AcceptStream() stream, err := sess.AcceptStream(context.Background())
if err != nil { if err != nil {
panic(err) panic(err)
} }

View file

@ -138,7 +138,7 @@ func (s *Server) handleConn(sess quic.Session) {
str.Write(buf.Bytes()) str.Write(buf.Bytes())
for { for {
str, err := sess.AcceptStream() str, err := sess.AcceptStream(context.Background())
if err != nil { if err != nil {
s.logger.Debugf("Accepting stream failed: %s", err) s.logger.Debugf("Accepting stream failed: %s", err)
return return

View file

@ -76,7 +76,7 @@ var _ = Describe("Stream Cancelations", func() {
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
defer wg.Done() defer wg.Done()
str, err := sess.AcceptUniStream() str, err := sess.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
// cancel around 2/3 of the streams // cancel around 2/3 of the streams
if rand.Int31()%3 != 0 { if rand.Int31()%3 != 0 {
@ -120,7 +120,7 @@ var _ = Describe("Stream Cancelations", func() {
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
defer wg.Done() defer wg.Done()
str, err := sess.AcceptUniStream() str, err := sess.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
// only read some data from about 1/3 of the streams // only read some data from about 1/3 of the streams
if rand.Int31()%3 != 0 { if rand.Int31()%3 != 0 {
@ -168,7 +168,7 @@ var _ = Describe("Stream Cancelations", func() {
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
defer wg.Done() defer wg.Done()
str, err := sess.AcceptUniStream() str, err := sess.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
data, err := ioutil.ReadAll(str) data, err := ioutil.ReadAll(str)
if err != nil { if err != nil {
@ -304,7 +304,7 @@ var _ = Describe("Stream Cancelations", func() {
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
defer wg.Done() defer wg.Done()
str, err := sess.AcceptUniStream() str, err := sess.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
// cancel around half of the streams // cancel around half of the streams
if rand.Int31()%2 == 0 { if rand.Int31()%2 == 0 {
@ -383,7 +383,7 @@ var _ = Describe("Stream Cancelations", func() {
defer GinkgoRecover() defer GinkgoRecover()
defer wg.Done() defer wg.Done()
str, err := sess.AcceptUniStream() str, err := sess.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
r := io.Reader(str) r := io.Reader(str)

View file

@ -52,7 +52,7 @@ var _ = Describe("Connection ID lengths tests", func() {
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer cl.Close() defer cl.Close()
str, err := cl.AcceptStream() str, err := cl.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
data, err := ioutil.ReadAll(str) data, err := ioutil.ReadAll(str)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())

View file

@ -30,7 +30,7 @@ var _ = Describe("Stream deadline tests", func() {
defer GinkgoRecover() defer GinkgoRecover()
sess, err := server.Accept(context.Background()) sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
serverStr, err = sess.AcceptStream() serverStr, err = sess.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
_, err = serverStr.Read([]byte{0}) _, err = serverStr.Read([]byte{0})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())

View file

@ -110,7 +110,7 @@ var _ = Describe("Drop Tests", func() {
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer sess.Close() defer sess.Close()
str, err := sess.AcceptStream() str, err := sess.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
for i := uint8(1); i <= numMessages; i++ { for i := uint8(1); i <= numMessages; i++ {
b := []byte{0} b := []byte{0}

View file

@ -61,7 +61,7 @@ var _ = Describe("Handshake drop tests", func() {
sess, err := ln.Accept(context.Background()) sess, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer sess.Close() defer sess.Close()
str, err := sess.AcceptStream() str, err := sess.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
b := make([]byte, 6) b := make([]byte, 6)
_, err = gbytes.TimeoutReader(str, 10*time.Second).Read(b) _, err = gbytes.TimeoutReader(str, 10*time.Second).Read(b)
@ -107,7 +107,7 @@ var _ = Describe("Handshake drop tests", func() {
&quic.Config{Versions: []protocol.VersionNumber{version}}, &quic.Config{Versions: []protocol.VersionNumber{version}},
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
str, err := sess.AcceptStream() str, err := sess.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
b := make([]byte, 6) b := make([]byte, 6)
_, err = gbytes.TimeoutReader(str, 10*time.Second).Read(b) _, err = gbytes.TimeoutReader(str, 10*time.Second).Read(b)

View file

@ -159,7 +159,7 @@ var _ = Describe("Handshake tests", func() {
errChan := make(chan error) errChan := make(chan error)
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
_, err := sess.AcceptStream() _, err := sess.AcceptStream(context.Background())
errChan <- err errChan <- err
}() }()
Eventually(errChan).Should(Receive(&err)) Eventually(errChan).Should(Receive(&err))

View file

@ -51,7 +51,7 @@ var _ = Describe("Multiplexing", func() {
&quic.Config{Versions: []protocol.VersionNumber{version}}, &quic.Config{Versions: []protocol.VersionNumber{version}},
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
str, err := sess.AcceptStream() str, err := sess.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
data, err := ioutil.ReadAll(str) data, err := ioutil.ReadAll(str)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())

View file

@ -68,7 +68,7 @@ var _ = Describe("non-zero RTT", func() {
&quic.Config{Versions: []protocol.VersionNumber{version}}, &quic.Config{Versions: []protocol.VersionNumber{version}},
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
str, err := sess.AcceptStream() str, err := sess.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
data, err := ioutil.ReadAll(str) data, err := ioutil.ReadAll(str)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())

View file

@ -64,7 +64,7 @@ var _ = Describe("Stateless Resets", func() {
}, },
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
str, err := sess.AcceptStream() str, err := sess.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
data := make([]byte, 6) data := make([]byte, 6)
_, err = str.Read(data) _, err = str.Read(data)

View file

@ -71,7 +71,7 @@ var _ = Describe("Bidirectional streams", func() {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(numStreams) wg.Add(numStreams)
for i := 0; i < numStreams; i++ { for i := 0; i < numStreams; i++ {
str, err := sess.AcceptStream() str, err := sess.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()

View file

@ -95,7 +95,7 @@ var _ = Describe("Timeout tests", func() {
&quic.Config{IdleTimeout: idleTimeout}, &quic.Config{IdleTimeout: idleTimeout},
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
strIn, err := sess.AcceptStream() strIn, err := sess.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
strOut, err := sess.OpenStream() strOut, err := sess.OpenStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -116,9 +116,9 @@ var _ = Describe("Timeout tests", func() {
checkTimeoutError(err) checkTimeoutError(err)
_, err = sess.OpenUniStream() _, err = sess.OpenUniStream()
checkTimeoutError(err) checkTimeoutError(err)
_, err = sess.AcceptStream() _, err = sess.AcceptStream(context.Background())
checkTimeoutError(err) checkTimeoutError(err)
_, err = sess.AcceptUniStream() _, err = sess.AcceptUniStream(context.Background())
checkTimeoutError(err) checkTimeoutError(err)
}) })
@ -148,7 +148,7 @@ var _ = Describe("Timeout tests", func() {
defer GinkgoRecover() defer GinkgoRecover()
sess, err := server.Accept(context.Background()) sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
sess.AcceptStream() // blocks until the session is closed sess.AcceptStream(context.Background()) // blocks until the session is closed
close(serverSessionClosed) close(serverSessionClosed)
}() }()
@ -162,7 +162,7 @@ var _ = Describe("Timeout tests", func() {
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
_, err := sess.AcceptStream() _, err := sess.AcceptStream(context.Background())
checkTimeoutError(err) checkTimeoutError(err)
close(done) close(done)
}() }()
@ -189,7 +189,7 @@ var _ = Describe("Timeout tests", func() {
defer GinkgoRecover() defer GinkgoRecover()
sess, err := server.Accept(context.Background()) sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
sess.AcceptStream() // blocks until the session is closed sess.AcceptStream(context.Background()) // blocks until the session is closed
close(serverSessionClosed) close(serverSessionClosed)
}() }()
@ -212,7 +212,7 @@ var _ = Describe("Timeout tests", func() {
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
_, err := sess.AcceptStream() _, err := sess.AcceptStream(context.Background())
checkTimeoutError(err) checkTimeoutError(err)
close(done) close(done)
}() }()

View file

@ -57,7 +57,7 @@ var _ = Describe("Unidirectional Streams", func() {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(numStreams) wg.Add(numStreams)
for i := 0; i < numStreams; i++ { for i := 0; i < numStreams; i++ {
str, err := sess.AcceptUniStream() str, err := sess.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()

View file

@ -127,11 +127,11 @@ type Session interface {
// AcceptStream returns the next stream opened by the peer, blocking until one is available. // AcceptStream returns the next stream opened by the peer, blocking until one is available.
// If the session was closed due to a timeout, the error satisfies // If the session was closed due to a timeout, the error satisfies
// the net.Error interface, and Timeout() will be true. // the net.Error interface, and Timeout() will be true.
AcceptStream() (Stream, error) AcceptStream(context.Context) (Stream, error)
// AcceptUniStream returns the next unidirectional stream opened by the peer, blocking until one is available. // AcceptUniStream returns the next unidirectional stream opened by the peer, blocking until one is available.
// If the session was closed due to a timeout, the error satisfies // If the session was closed due to a timeout, the error satisfies
// the net.Error interface, and Timeout() will be true. // the net.Error interface, and Timeout() will be true.
AcceptUniStream() (ReceiveStream, error) AcceptUniStream(context.Context) (ReceiveStream, error)
// OpenStream opens a new bidirectional QUIC stream. // OpenStream opens a new bidirectional QUIC stream.
// There is no signaling to the peer about new streams: // There is no signaling to the peer about new streams:
// The peer can only accept the stream after data has been sent on the stream. // The peer can only accept the stream after data has been sent on the stream.

View file

@ -39,33 +39,33 @@ func (m *MockSession) EXPECT() *MockSessionMockRecorder {
} }
// AcceptStream mocks base method // AcceptStream mocks base method
func (m *MockSession) AcceptStream() (quic_go.Stream, error) { func (m *MockSession) AcceptStream(arg0 context.Context) (quic_go.Stream, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AcceptStream") ret := m.ctrl.Call(m, "AcceptStream", arg0)
ret0, _ := ret[0].(quic_go.Stream) ret0, _ := ret[0].(quic_go.Stream)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// AcceptStream indicates an expected call of AcceptStream // AcceptStream indicates an expected call of AcceptStream
func (mr *MockSessionMockRecorder) AcceptStream() *gomock.Call { func (mr *MockSessionMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockSession)(nil).AcceptStream)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockSession)(nil).AcceptStream), arg0)
} }
// AcceptUniStream mocks base method // AcceptUniStream mocks base method
func (m *MockSession) AcceptUniStream() (quic_go.ReceiveStream, error) { func (m *MockSession) AcceptUniStream(arg0 context.Context) (quic_go.ReceiveStream, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AcceptUniStream") ret := m.ctrl.Call(m, "AcceptUniStream", arg0)
ret0, _ := ret[0].(quic_go.ReceiveStream) ret0, _ := ret[0].(quic_go.ReceiveStream)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// AcceptUniStream indicates an expected call of AcceptUniStream // AcceptUniStream indicates an expected call of AcceptUniStream
func (mr *MockSessionMockRecorder) AcceptUniStream() *gomock.Call { func (mr *MockSessionMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockSession)(nil).AcceptUniStream)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockSession)(nil).AcceptUniStream), arg0)
} }
// Close mocks base method // Close mocks base method

View file

@ -38,33 +38,33 @@ func (m *MockQuicSession) EXPECT() *MockQuicSessionMockRecorder {
} }
// AcceptStream mocks base method // AcceptStream mocks base method
func (m *MockQuicSession) AcceptStream() (Stream, error) { func (m *MockQuicSession) AcceptStream(arg0 context.Context) (Stream, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AcceptStream") ret := m.ctrl.Call(m, "AcceptStream", arg0)
ret0, _ := ret[0].(Stream) ret0, _ := ret[0].(Stream)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// AcceptStream indicates an expected call of AcceptStream // AcceptStream indicates an expected call of AcceptStream
func (mr *MockQuicSessionMockRecorder) AcceptStream() *gomock.Call { func (mr *MockQuicSessionMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockQuicSession)(nil).AcceptStream)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockQuicSession)(nil).AcceptStream), arg0)
} }
// AcceptUniStream mocks base method // AcceptUniStream mocks base method
func (m *MockQuicSession) AcceptUniStream() (ReceiveStream, error) { func (m *MockQuicSession) AcceptUniStream(arg0 context.Context) (ReceiveStream, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AcceptUniStream") ret := m.ctrl.Call(m, "AcceptUniStream", arg0)
ret0, _ := ret[0].(ReceiveStream) ret0, _ := ret[0].(ReceiveStream)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// AcceptUniStream indicates an expected call of AcceptUniStream // AcceptUniStream indicates an expected call of AcceptUniStream
func (mr *MockQuicSessionMockRecorder) AcceptUniStream() *gomock.Call { func (mr *MockQuicSessionMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockQuicSession)(nil).AcceptUniStream)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockQuicSession)(nil).AcceptUniStream), arg0)
} }
// Close mocks base method // Close mocks base method

View file

@ -5,6 +5,7 @@
package quic package quic
import ( import (
context "context"
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
@ -37,33 +38,33 @@ func (m *MockStreamManager) EXPECT() *MockStreamManagerMockRecorder {
} }
// AcceptStream mocks base method // AcceptStream mocks base method
func (m *MockStreamManager) AcceptStream() (Stream, error) { func (m *MockStreamManager) AcceptStream(arg0 context.Context) (Stream, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AcceptStream") ret := m.ctrl.Call(m, "AcceptStream", arg0)
ret0, _ := ret[0].(Stream) ret0, _ := ret[0].(Stream)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// AcceptStream indicates an expected call of AcceptStream // AcceptStream indicates an expected call of AcceptStream
func (mr *MockStreamManagerMockRecorder) AcceptStream() *gomock.Call { func (mr *MockStreamManagerMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptStream)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptStream), arg0)
} }
// AcceptUniStream mocks base method // AcceptUniStream mocks base method
func (m *MockStreamManager) AcceptUniStream() (ReceiveStream, error) { func (m *MockStreamManager) AcceptUniStream(arg0 context.Context) (ReceiveStream, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AcceptUniStream") ret := m.ctrl.Call(m, "AcceptUniStream", arg0)
ret0, _ := ret[0].(ReceiveStream) ret0, _ := ret[0].(ReceiveStream)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// AcceptUniStream indicates an expected call of AcceptUniStream // AcceptUniStream indicates an expected call of AcceptUniStream
func (mr *MockStreamManagerMockRecorder) AcceptUniStream() *gomock.Call { func (mr *MockStreamManagerMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptUniStream)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptUniStream), arg0)
} }
// CloseWithError mocks base method // CloseWithError mocks base method

View file

@ -39,8 +39,8 @@ type streamManager interface {
OpenUniStream() (SendStream, error) OpenUniStream() (SendStream, error)
OpenStreamSync() (Stream, error) OpenStreamSync() (Stream, error)
OpenUniStreamSync() (SendStream, error) OpenUniStreamSync() (SendStream, error)
AcceptStream() (Stream, error) AcceptStream(context.Context) (Stream, error)
AcceptUniStream() (ReceiveStream, error) AcceptUniStream(context.Context) (ReceiveStream, error)
DeleteStream(protocol.StreamID) error DeleteStream(protocol.StreamID) error
UpdateLimits(*handshake.TransportParameters) error UpdateLimits(*handshake.TransportParameters) error
HandleMaxStreamsFrame(*wire.MaxStreamsFrame) error HandleMaxStreamsFrame(*wire.MaxStreamsFrame) error
@ -1233,12 +1233,12 @@ func (s *session) logPacket(packet *packedPacket) {
} }
// AcceptStream returns the next stream openend by the peer // AcceptStream returns the next stream openend by the peer
func (s *session) AcceptStream() (Stream, error) { func (s *session) AcceptStream(ctx context.Context) (Stream, error) {
return s.streamsMap.AcceptStream() return s.streamsMap.AcceptStream(ctx)
} }
func (s *session) AcceptUniStream() (ReceiveStream, error) { func (s *session) AcceptUniStream(ctx context.Context) (ReceiveStream, error) {
return s.streamsMap.AcceptUniStream() return s.streamsMap.AcceptUniStream(ctx)
} }
// OpenStream opens a stream // OpenStream opens a stream

View file

@ -367,14 +367,6 @@ var _ = Describe("Session", func() {
Expect(sess.GetVersion()).To(Equal(protocol.VersionNumber(4242))) Expect(sess.GetVersion()).To(Equal(protocol.VersionNumber(4242)))
}) })
It("accepts new streams", func() {
mstr := NewMockStreamI(mockCtrl)
streamManager.EXPECT().AcceptStream().Return(mstr, nil)
str, err := sess.AcceptStream()
Expect(err).ToNot(HaveOccurred())
Expect(str).To(Equal(mstr))
})
Context("closing", func() { Context("closing", func() {
var ( var (
runErr error runErr error
@ -1454,17 +1446,21 @@ var _ = Describe("Session", func() {
}) })
It("accepts streams", func() { It("accepts streams", func() {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
mstr := NewMockStreamI(mockCtrl) mstr := NewMockStreamI(mockCtrl)
streamManager.EXPECT().AcceptStream().Return(mstr, nil) streamManager.EXPECT().AcceptStream(ctx).Return(mstr, nil)
str, err := sess.AcceptStream() str, err := sess.AcceptStream(ctx)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(str).To(Equal(mstr)) Expect(str).To(Equal(mstr))
}) })
It("accepts unidirectional streams", func() { It("accepts unidirectional streams", func() {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
mstr := NewMockReceiveStreamI(mockCtrl) mstr := NewMockReceiveStreamI(mockCtrl)
streamManager.EXPECT().AcceptUniStream().Return(mstr, nil) streamManager.EXPECT().AcceptUniStream(ctx).Return(mstr, nil)
str, err := sess.AcceptUniStream() str, err := sess.AcceptUniStream(ctx)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(str).To(Equal(mstr)) Expect(str).To(Equal(mstr))
}) })

View file

@ -1,6 +1,7 @@
package quic package quic
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net" "net"
@ -123,13 +124,13 @@ func (m *streamsMap) OpenUniStreamSync() (SendStream, error) {
return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective) return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
} }
func (m *streamsMap) AcceptStream() (Stream, error) { func (m *streamsMap) AcceptStream(ctx context.Context) (Stream, error) {
str, err := m.incomingBidiStreams.AcceptStream() str, err := m.incomingBidiStreams.AcceptStream(ctx)
return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective.Opposite()) return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective.Opposite())
} }
func (m *streamsMap) AcceptUniStream() (ReceiveStream, error) { func (m *streamsMap) AcceptUniStream(ctx context.Context) (ReceiveStream, error) {
str, err := m.incomingUniStreams.AcceptStream() str, err := m.incomingUniStreams.AcceptStream(ctx)
return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective.Opposite()) return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective.Opposite())
} }

View file

@ -5,6 +5,7 @@
package quic package quic
import ( import (
"context"
"sync" "sync"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
@ -50,24 +51,28 @@ func newIncomingBidiStreamsMap(
} }
} }
func (m *incomingBidiStreamsMap) AcceptStream() (streamI, error) { func (m *incomingBidiStreamsMap) AcceptStream(ctx context.Context) (streamI, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock()
var num protocol.StreamNum var num protocol.StreamNum
var str streamI var str streamI
for { for {
num = m.nextStreamToAccept num = m.nextStreamToAccept
var ok bool
if m.closeErr != nil { if m.closeErr != nil {
m.mutex.Unlock()
return nil, m.closeErr return nil, m.closeErr
} }
var ok bool
str, ok = m.streams[num] str, ok = m.streams[num]
if ok { if ok {
break break
} }
m.mutex.Unlock() m.mutex.Unlock()
<-m.newStreamChan select {
case <-ctx.Done():
return nil, ctx.Err()
case <-m.newStreamChan:
}
m.mutex.Lock() m.mutex.Lock()
} }
m.nextStreamToAccept++ m.nextStreamToAccept++
@ -75,9 +80,11 @@ func (m *incomingBidiStreamsMap) AcceptStream() (streamI, error) {
if _, ok := m.streamsToDelete[num]; ok { if _, ok := m.streamsToDelete[num]; ok {
delete(m.streamsToDelete, num) delete(m.streamsToDelete, num)
if err := m.deleteStream(num); err != nil { if err := m.deleteStream(num); err != nil {
m.mutex.Unlock()
return nil, err return nil, err
} }
} }
m.mutex.Unlock()
return str, nil return str, nil
} }

View file

@ -1,6 +1,7 @@
package quic package quic
import ( import (
"context"
"sync" "sync"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
@ -48,24 +49,28 @@ func newIncomingItemsMap(
} }
} }
func (m *incomingItemsMap) AcceptStream() (item, error) { func (m *incomingItemsMap) AcceptStream(ctx context.Context) (item, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock()
var num protocol.StreamNum var num protocol.StreamNum
var str item var str item
for { for {
num = m.nextStreamToAccept num = m.nextStreamToAccept
var ok bool
if m.closeErr != nil { if m.closeErr != nil {
m.mutex.Unlock()
return nil, m.closeErr return nil, m.closeErr
} }
var ok bool
str, ok = m.streams[num] str, ok = m.streams[num]
if ok { if ok {
break break
} }
m.mutex.Unlock() m.mutex.Unlock()
<-m.newStreamChan select {
case <-ctx.Done():
return nil, ctx.Err()
case <-m.newStreamChan:
}
m.mutex.Lock() m.mutex.Lock()
} }
m.nextStreamToAccept++ m.nextStreamToAccept++
@ -73,9 +78,11 @@ func (m *incomingItemsMap) AcceptStream() (item, error) {
if _, ok := m.streamsToDelete[num]; ok { if _, ok := m.streamsToDelete[num]; ok {
delete(m.streamsToDelete, num) delete(m.streamsToDelete, num)
if err := m.deleteStream(num); err != nil { if err := m.deleteStream(num); err != nil {
m.mutex.Unlock()
return nil, err return nil, err
} }
} }
m.mutex.Unlock()
return str, nil return str, nil
} }

View file

@ -1,6 +1,7 @@
package quic package quic
import ( import (
"context"
"errors" "errors"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
@ -66,10 +67,10 @@ var _ = Describe("Streams Map (incoming)", func() {
It("accepts streams in the right order", func() { It("accepts streams in the right order", func() {
_, err := m.GetOrOpenStream(2) // open streams 1 and 2 _, err := m.GetOrOpenStream(2) // open streams 1 and 2
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
str, err := m.AcceptStream() str, err := m.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
str, err = m.AcceptStream() str, err = m.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2)))
}) })
@ -90,7 +91,7 @@ var _ = Describe("Streams Map (incoming)", func() {
strChan := make(chan item) strChan := make(chan item)
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
str, err := m.AcceptStream() str, err := m.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
strChan <- str strChan <- str
}() }()
@ -103,12 +104,26 @@ var _ = Describe("Streams Map (incoming)", func() {
Expect(acceptedStr.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) Expect(acceptedStr.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
}) })
It("unblocks AcceptStream when the context is canceled", func() {
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := m.AcceptStream(ctx)
Expect(err).To(MatchError("context canceled"))
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
cancel()
Eventually(done).Should(BeClosed())
})
It("unblocks AcceptStream when it is closed", func() { It("unblocks AcceptStream when it is closed", func() {
testErr := errors.New("test error") testErr := errors.New("test error")
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
_, err := m.AcceptStream() _, err := m.AcceptStream(context.Background())
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
close(done) close(done)
}() }()
@ -120,7 +135,7 @@ var _ = Describe("Streams Map (incoming)", func() {
It("errors AcceptStream immediately if it is closed", func() { It("errors AcceptStream immediately if it is closed", func() {
testErr := errors.New("test error") testErr := errors.New("test error")
m.CloseWithError(testErr) m.CloseWithError(testErr)
_, err := m.AcceptStream() _, err := m.AcceptStream(context.Background())
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
}) })
@ -141,7 +156,7 @@ var _ = Describe("Streams Map (incoming)", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any()) mockSender.EXPECT().queueControlFrame(gomock.Any())
_, err := m.GetOrOpenStream(1) _, err := m.GetOrOpenStream(1)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
str, err := m.AcceptStream() str, err := m.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
Expect(m.DeleteStream(1)).To(Succeed()) Expect(m.DeleteStream(1)).To(Succeed())
@ -154,12 +169,12 @@ var _ = Describe("Streams Map (incoming)", func() {
_, err := m.GetOrOpenStream(2) _, err := m.GetOrOpenStream(2)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(m.DeleteStream(2)).To(Succeed()) Expect(m.DeleteStream(2)).To(Succeed())
str, err := m.AcceptStream() str, err := m.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
// when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued // when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued
mockSender.EXPECT().queueControlFrame(gomock.Any()) mockSender.EXPECT().queueControlFrame(gomock.Any())
str, err = m.AcceptStream() str, err = m.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2)))
}) })
@ -174,7 +189,7 @@ var _ = Describe("Streams Map (incoming)", func() {
Expect(str).To(BeNil()) Expect(str).To(BeNil())
// when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued // when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued
mockSender.EXPECT().queueControlFrame(gomock.Any()) mockSender.EXPECT().queueControlFrame(gomock.Any())
str, err = m.AcceptStream() str, err = m.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(str).ToNot(BeNil()) Expect(str).ToNot(BeNil())
}) })
@ -191,7 +206,7 @@ var _ = Describe("Streams Map (incoming)", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
// accept all streams // accept all streams
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
_, err := m.AcceptStream() _, err := m.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
} }
mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {

View file

@ -5,6 +5,7 @@
package quic package quic
import ( import (
"context"
"sync" "sync"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
@ -50,24 +51,28 @@ func newIncomingUniStreamsMap(
} }
} }
func (m *incomingUniStreamsMap) AcceptStream() (receiveStreamI, error) { func (m *incomingUniStreamsMap) AcceptStream(ctx context.Context) (receiveStreamI, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock()
var num protocol.StreamNum var num protocol.StreamNum
var str receiveStreamI var str receiveStreamI
for { for {
num = m.nextStreamToAccept num = m.nextStreamToAccept
var ok bool
if m.closeErr != nil { if m.closeErr != nil {
m.mutex.Unlock()
return nil, m.closeErr return nil, m.closeErr
} }
var ok bool
str, ok = m.streams[num] str, ok = m.streams[num]
if ok { if ok {
break break
} }
m.mutex.Unlock() m.mutex.Unlock()
<-m.newStreamChan select {
case <-ctx.Done():
return nil, ctx.Err()
case <-m.newStreamChan:
}
m.mutex.Lock() m.mutex.Lock()
} }
m.nextStreamToAccept++ m.nextStreamToAccept++
@ -75,9 +80,11 @@ func (m *incomingUniStreamsMap) AcceptStream() (receiveStreamI, error) {
if _, ok := m.streamsToDelete[num]; ok { if _, ok := m.streamsToDelete[num]; ok {
delete(m.streamsToDelete, num) delete(m.streamsToDelete, num)
if err := m.deleteStream(num); err != nil { if err := m.deleteStream(num); err != nil {
m.mutex.Unlock()
return nil, err return nil, err
} }
} }
m.mutex.Unlock()
return str, nil return str, nil
} }

View file

@ -1,6 +1,7 @@
package quic package quic
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net" "net"
@ -121,7 +122,7 @@ var _ = Describe("Streams Map", func() {
It("accepts bidirectional streams", func() { It("accepts bidirectional streams", func() {
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream) _, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
str, err := m.AcceptStream() str, err := m.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(str).To(BeAssignableToTypeOf(&stream{})) Expect(str).To(BeAssignableToTypeOf(&stream{}))
Expect(str.StreamID()).To(Equal(ids.firstIncomingBidiStream)) Expect(str.StreamID()).To(Equal(ids.firstIncomingBidiStream))
@ -130,7 +131,7 @@ var _ = Describe("Streams Map", func() {
It("accepts unidirectional streams", func() { It("accepts unidirectional streams", func() {
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream) _, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
str, err := m.AcceptUniStream() str, err := m.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(str).To(BeAssignableToTypeOf(&receiveStream{})) Expect(str).To(BeAssignableToTypeOf(&receiveStream{}))
Expect(str.StreamID()).To(Equal(ids.firstIncomingUniStream)) Expect(str.StreamID()).To(Equal(ids.firstIncomingUniStream))
@ -170,7 +171,7 @@ var _ = Describe("Streams Map", func() {
_, err := m.GetOrOpenReceiveStream(id) _, err := m.GetOrOpenReceiveStream(id)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(m.DeleteStream(id)).To(Succeed()) Expect(m.DeleteStream(id)).To(Succeed())
str, err := m.AcceptStream() str, err := m.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(str).ToNot(BeNil()) Expect(str).ToNot(BeNil())
Expect(str.StreamID()).To(Equal(id)) Expect(str.StreamID()).To(Equal(id))
@ -203,7 +204,7 @@ var _ = Describe("Streams Map", func() {
_, err := m.GetOrOpenReceiveStream(id) _, err := m.GetOrOpenReceiveStream(id)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(m.DeleteStream(id)).To(Succeed()) Expect(m.DeleteStream(id)).To(Succeed())
str, err := m.AcceptUniStream() str, err := m.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(str).ToNot(BeNil()) Expect(str).ToNot(BeNil())
Expect(str.StreamID()).To(Equal(id)) Expect(str.StreamID()).To(Equal(id))
@ -393,7 +394,7 @@ var _ = Describe("Streams Map", func() {
It("sends a MAX_STREAMS frame for bidirectional streams", func() { It("sends a MAX_STREAMS frame for bidirectional streams", func() {
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream) _, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
_, err = m.AcceptStream() _, err = m.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{ mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{
Type: protocol.StreamTypeBidi, Type: protocol.StreamTypeBidi,
@ -405,7 +406,7 @@ var _ = Describe("Streams Map", func() {
It("sends a MAX_STREAMS frame for unidirectional streams", func() { It("sends a MAX_STREAMS frame for unidirectional streams", func() {
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream) _, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
_, err = m.AcceptUniStream() _, err = m.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{ mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{
Type: protocol.StreamTypeUni, Type: protocol.StreamTypeUni,
@ -424,10 +425,10 @@ var _ = Describe("Streams Map", func() {
_, err = m.OpenUniStream() _, err = m.OpenUniStream()
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal(testErr.Error())) Expect(err.Error()).To(Equal(testErr.Error()))
_, err = m.AcceptStream() _, err = m.AcceptStream(context.Background())
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal(testErr.Error())) Expect(err.Error()).To(Equal(testErr.Error()))
_, err = m.AcceptUniStream() _, err = m.AcceptUniStream(context.Background())
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal(testErr.Error())) Expect(err.Error()).To(Equal(testErr.Error()))
}) })