mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
add a context to Session.Accept{Uni}Stream
This commit is contained in:
parent
f74082b2fb
commit
5550ba2c3b
28 changed files with 140 additions and 105 deletions
|
@ -8,7 +8,7 @@
|
|||
- Enforce application protocol negotiation (via `tls.Config.NextProtos`).
|
||||
- Use a varint for error codes.
|
||||
- Add support for [quic-trace](https://github.com/google/quic-trace).
|
||||
- Add a context to `Listener.Accept`.
|
||||
- Add a context to `Listener.Accept` and `Session.Accept{Uni}Stream`.
|
||||
|
||||
## v0.11.0 (2019-04-05)
|
||||
|
||||
|
|
|
@ -67,7 +67,7 @@ func init() {
|
|||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(handshakeChan)
|
||||
str, err := sess.AcceptStream()
|
||||
str, err := sess.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
|
|
|
@ -40,7 +40,7 @@ func echoServer() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stream, err := sess.AcceptStream()
|
||||
stream, err := sess.AcceptStream(context.Background())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
|
|
@ -138,7 +138,7 @@ func (s *Server) handleConn(sess quic.Session) {
|
|||
str.Write(buf.Bytes())
|
||||
|
||||
for {
|
||||
str, err := sess.AcceptStream()
|
||||
str, err := sess.AcceptStream(context.Background())
|
||||
if err != nil {
|
||||
s.logger.Debugf("Accepting stream failed: %s", err)
|
||||
return
|
||||
|
|
|
@ -76,7 +76,7 @@ var _ = Describe("Stream Cancelations", func() {
|
|||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer wg.Done()
|
||||
str, err := sess.AcceptUniStream()
|
||||
str, err := sess.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// cancel around 2/3 of the streams
|
||||
if rand.Int31()%3 != 0 {
|
||||
|
@ -120,7 +120,7 @@ var _ = Describe("Stream Cancelations", func() {
|
|||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer wg.Done()
|
||||
str, err := sess.AcceptUniStream()
|
||||
str, err := sess.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// only read some data from about 1/3 of the streams
|
||||
if rand.Int31()%3 != 0 {
|
||||
|
@ -168,7 +168,7 @@ var _ = Describe("Stream Cancelations", func() {
|
|||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer wg.Done()
|
||||
str, err := sess.AcceptUniStream()
|
||||
str, err := sess.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := ioutil.ReadAll(str)
|
||||
if err != nil {
|
||||
|
@ -304,7 +304,7 @@ var _ = Describe("Stream Cancelations", func() {
|
|||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer wg.Done()
|
||||
str, err := sess.AcceptUniStream()
|
||||
str, err := sess.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// cancel around half of the streams
|
||||
if rand.Int31()%2 == 0 {
|
||||
|
@ -383,7 +383,7 @@ var _ = Describe("Stream Cancelations", func() {
|
|||
defer GinkgoRecover()
|
||||
defer wg.Done()
|
||||
|
||||
str, err := sess.AcceptUniStream()
|
||||
str, err := sess.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
r := io.Reader(str)
|
||||
|
|
|
@ -52,7 +52,7 @@ var _ = Describe("Connection ID lengths tests", func() {
|
|||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer cl.Close()
|
||||
str, err := cl.AcceptStream()
|
||||
str, err := cl.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := ioutil.ReadAll(str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
|
|
@ -30,7 +30,7 @@ var _ = Describe("Stream deadline tests", func() {
|
|||
defer GinkgoRecover()
|
||||
sess, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverStr, err = sess.AcceptStream()
|
||||
serverStr, err = sess.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = serverStr.Read([]byte{0})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
|
|
@ -110,7 +110,7 @@ var _ = Describe("Drop Tests", func() {
|
|||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer sess.Close()
|
||||
str, err := sess.AcceptStream()
|
||||
str, err := sess.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
for i := uint8(1); i <= numMessages; i++ {
|
||||
b := []byte{0}
|
||||
|
|
|
@ -61,7 +61,7 @@ var _ = Describe("Handshake drop tests", func() {
|
|||
sess, err := ln.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer sess.Close()
|
||||
str, err := sess.AcceptStream()
|
||||
str, err := sess.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b := make([]byte, 6)
|
||||
_, err = gbytes.TimeoutReader(str, 10*time.Second).Read(b)
|
||||
|
@ -107,7 +107,7 @@ var _ = Describe("Handshake drop tests", func() {
|
|||
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := sess.AcceptStream()
|
||||
str, err := sess.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b := make([]byte, 6)
|
||||
_, err = gbytes.TimeoutReader(str, 10*time.Second).Read(b)
|
||||
|
|
|
@ -159,7 +159,7 @@ var _ = Describe("Handshake tests", func() {
|
|||
errChan := make(chan error)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := sess.AcceptStream()
|
||||
_, err := sess.AcceptStream(context.Background())
|
||||
errChan <- err
|
||||
}()
|
||||
Eventually(errChan).Should(Receive(&err))
|
||||
|
|
|
@ -51,7 +51,7 @@ var _ = Describe("Multiplexing", func() {
|
|||
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := sess.AcceptStream()
|
||||
str, err := sess.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := ioutil.ReadAll(str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
|
|
@ -68,7 +68,7 @@ var _ = Describe("non-zero RTT", func() {
|
|||
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := sess.AcceptStream()
|
||||
str, err := sess.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := ioutil.ReadAll(str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
|
|
@ -64,7 +64,7 @@ var _ = Describe("Stateless Resets", func() {
|
|||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := sess.AcceptStream()
|
||||
str, err := sess.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data := make([]byte, 6)
|
||||
_, err = str.Read(data)
|
||||
|
|
|
@ -71,7 +71,7 @@ var _ = Describe("Bidirectional streams", func() {
|
|||
var wg sync.WaitGroup
|
||||
wg.Add(numStreams)
|
||||
for i := 0; i < numStreams; i++ {
|
||||
str, err := sess.AcceptStream()
|
||||
str, err := sess.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
|
|
|
@ -95,7 +95,7 @@ var _ = Describe("Timeout tests", func() {
|
|||
&quic.Config{IdleTimeout: idleTimeout},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
strIn, err := sess.AcceptStream()
|
||||
strIn, err := sess.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
strOut, err := sess.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -116,9 +116,9 @@ var _ = Describe("Timeout tests", func() {
|
|||
checkTimeoutError(err)
|
||||
_, err = sess.OpenUniStream()
|
||||
checkTimeoutError(err)
|
||||
_, err = sess.AcceptStream()
|
||||
_, err = sess.AcceptStream(context.Background())
|
||||
checkTimeoutError(err)
|
||||
_, err = sess.AcceptUniStream()
|
||||
_, err = sess.AcceptUniStream(context.Background())
|
||||
checkTimeoutError(err)
|
||||
})
|
||||
|
||||
|
@ -148,7 +148,7 @@ var _ = Describe("Timeout tests", func() {
|
|||
defer GinkgoRecover()
|
||||
sess, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
sess.AcceptStream() // blocks until the session is closed
|
||||
sess.AcceptStream(context.Background()) // blocks until the session is closed
|
||||
close(serverSessionClosed)
|
||||
}()
|
||||
|
||||
|
@ -162,7 +162,7 @@ var _ = Describe("Timeout tests", func() {
|
|||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := sess.AcceptStream()
|
||||
_, err := sess.AcceptStream(context.Background())
|
||||
checkTimeoutError(err)
|
||||
close(done)
|
||||
}()
|
||||
|
@ -189,7 +189,7 @@ var _ = Describe("Timeout tests", func() {
|
|||
defer GinkgoRecover()
|
||||
sess, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
sess.AcceptStream() // blocks until the session is closed
|
||||
sess.AcceptStream(context.Background()) // blocks until the session is closed
|
||||
close(serverSessionClosed)
|
||||
}()
|
||||
|
||||
|
@ -212,7 +212,7 @@ var _ = Describe("Timeout tests", func() {
|
|||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := sess.AcceptStream()
|
||||
_, err := sess.AcceptStream(context.Background())
|
||||
checkTimeoutError(err)
|
||||
close(done)
|
||||
}()
|
||||
|
|
|
@ -57,7 +57,7 @@ var _ = Describe("Unidirectional Streams", func() {
|
|||
var wg sync.WaitGroup
|
||||
wg.Add(numStreams)
|
||||
for i := 0; i < numStreams; i++ {
|
||||
str, err := sess.AcceptUniStream()
|
||||
str, err := sess.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
|
|
|
@ -127,11 +127,11 @@ type Session interface {
|
|||
// AcceptStream returns the next stream opened by the peer, blocking until one is available.
|
||||
// If the session was closed due to a timeout, the error satisfies
|
||||
// the net.Error interface, and Timeout() will be true.
|
||||
AcceptStream() (Stream, error)
|
||||
AcceptStream(context.Context) (Stream, error)
|
||||
// AcceptUniStream returns the next unidirectional stream opened by the peer, blocking until one is available.
|
||||
// If the session was closed due to a timeout, the error satisfies
|
||||
// the net.Error interface, and Timeout() will be true.
|
||||
AcceptUniStream() (ReceiveStream, error)
|
||||
AcceptUniStream(context.Context) (ReceiveStream, error)
|
||||
// OpenStream opens a new bidirectional QUIC stream.
|
||||
// There is no signaling to the peer about new streams:
|
||||
// The peer can only accept the stream after data has been sent on the stream.
|
||||
|
|
|
@ -39,33 +39,33 @@ func (m *MockSession) EXPECT() *MockSessionMockRecorder {
|
|||
}
|
||||
|
||||
// AcceptStream mocks base method
|
||||
func (m *MockSession) AcceptStream() (quic_go.Stream, error) {
|
||||
func (m *MockSession) AcceptStream(arg0 context.Context) (quic_go.Stream, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AcceptStream")
|
||||
ret := m.ctrl.Call(m, "AcceptStream", arg0)
|
||||
ret0, _ := ret[0].(quic_go.Stream)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AcceptStream indicates an expected call of AcceptStream
|
||||
func (mr *MockSessionMockRecorder) AcceptStream() *gomock.Call {
|
||||
func (mr *MockSessionMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockSession)(nil).AcceptStream))
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockSession)(nil).AcceptStream), arg0)
|
||||
}
|
||||
|
||||
// AcceptUniStream mocks base method
|
||||
func (m *MockSession) AcceptUniStream() (quic_go.ReceiveStream, error) {
|
||||
func (m *MockSession) AcceptUniStream(arg0 context.Context) (quic_go.ReceiveStream, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AcceptUniStream")
|
||||
ret := m.ctrl.Call(m, "AcceptUniStream", arg0)
|
||||
ret0, _ := ret[0].(quic_go.ReceiveStream)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AcceptUniStream indicates an expected call of AcceptUniStream
|
||||
func (mr *MockSessionMockRecorder) AcceptUniStream() *gomock.Call {
|
||||
func (mr *MockSessionMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockSession)(nil).AcceptUniStream))
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockSession)(nil).AcceptUniStream), arg0)
|
||||
}
|
||||
|
||||
// Close mocks base method
|
||||
|
|
|
@ -38,33 +38,33 @@ func (m *MockQuicSession) EXPECT() *MockQuicSessionMockRecorder {
|
|||
}
|
||||
|
||||
// AcceptStream mocks base method
|
||||
func (m *MockQuicSession) AcceptStream() (Stream, error) {
|
||||
func (m *MockQuicSession) AcceptStream(arg0 context.Context) (Stream, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AcceptStream")
|
||||
ret := m.ctrl.Call(m, "AcceptStream", arg0)
|
||||
ret0, _ := ret[0].(Stream)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AcceptStream indicates an expected call of AcceptStream
|
||||
func (mr *MockQuicSessionMockRecorder) AcceptStream() *gomock.Call {
|
||||
func (mr *MockQuicSessionMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockQuicSession)(nil).AcceptStream))
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockQuicSession)(nil).AcceptStream), arg0)
|
||||
}
|
||||
|
||||
// AcceptUniStream mocks base method
|
||||
func (m *MockQuicSession) AcceptUniStream() (ReceiveStream, error) {
|
||||
func (m *MockQuicSession) AcceptUniStream(arg0 context.Context) (ReceiveStream, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AcceptUniStream")
|
||||
ret := m.ctrl.Call(m, "AcceptUniStream", arg0)
|
||||
ret0, _ := ret[0].(ReceiveStream)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AcceptUniStream indicates an expected call of AcceptUniStream
|
||||
func (mr *MockQuicSessionMockRecorder) AcceptUniStream() *gomock.Call {
|
||||
func (mr *MockQuicSessionMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockQuicSession)(nil).AcceptUniStream))
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockQuicSession)(nil).AcceptUniStream), arg0)
|
||||
}
|
||||
|
||||
// Close mocks base method
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
|
@ -37,33 +38,33 @@ func (m *MockStreamManager) EXPECT() *MockStreamManagerMockRecorder {
|
|||
}
|
||||
|
||||
// AcceptStream mocks base method
|
||||
func (m *MockStreamManager) AcceptStream() (Stream, error) {
|
||||
func (m *MockStreamManager) AcceptStream(arg0 context.Context) (Stream, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AcceptStream")
|
||||
ret := m.ctrl.Call(m, "AcceptStream", arg0)
|
||||
ret0, _ := ret[0].(Stream)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AcceptStream indicates an expected call of AcceptStream
|
||||
func (mr *MockStreamManagerMockRecorder) AcceptStream() *gomock.Call {
|
||||
func (mr *MockStreamManagerMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptStream))
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptStream), arg0)
|
||||
}
|
||||
|
||||
// AcceptUniStream mocks base method
|
||||
func (m *MockStreamManager) AcceptUniStream() (ReceiveStream, error) {
|
||||
func (m *MockStreamManager) AcceptUniStream(arg0 context.Context) (ReceiveStream, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AcceptUniStream")
|
||||
ret := m.ctrl.Call(m, "AcceptUniStream", arg0)
|
||||
ret0, _ := ret[0].(ReceiveStream)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AcceptUniStream indicates an expected call of AcceptUniStream
|
||||
func (mr *MockStreamManagerMockRecorder) AcceptUniStream() *gomock.Call {
|
||||
func (mr *MockStreamManagerMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptUniStream))
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptUniStream), arg0)
|
||||
}
|
||||
|
||||
// CloseWithError mocks base method
|
||||
|
|
12
session.go
12
session.go
|
@ -39,8 +39,8 @@ type streamManager interface {
|
|||
OpenUniStream() (SendStream, error)
|
||||
OpenStreamSync() (Stream, error)
|
||||
OpenUniStreamSync() (SendStream, error)
|
||||
AcceptStream() (Stream, error)
|
||||
AcceptUniStream() (ReceiveStream, error)
|
||||
AcceptStream(context.Context) (Stream, error)
|
||||
AcceptUniStream(context.Context) (ReceiveStream, error)
|
||||
DeleteStream(protocol.StreamID) error
|
||||
UpdateLimits(*handshake.TransportParameters) error
|
||||
HandleMaxStreamsFrame(*wire.MaxStreamsFrame) error
|
||||
|
@ -1233,12 +1233,12 @@ func (s *session) logPacket(packet *packedPacket) {
|
|||
}
|
||||
|
||||
// AcceptStream returns the next stream openend by the peer
|
||||
func (s *session) AcceptStream() (Stream, error) {
|
||||
return s.streamsMap.AcceptStream()
|
||||
func (s *session) AcceptStream(ctx context.Context) (Stream, error) {
|
||||
return s.streamsMap.AcceptStream(ctx)
|
||||
}
|
||||
|
||||
func (s *session) AcceptUniStream() (ReceiveStream, error) {
|
||||
return s.streamsMap.AcceptUniStream()
|
||||
func (s *session) AcceptUniStream(ctx context.Context) (ReceiveStream, error) {
|
||||
return s.streamsMap.AcceptUniStream(ctx)
|
||||
}
|
||||
|
||||
// OpenStream opens a stream
|
||||
|
|
|
@ -367,14 +367,6 @@ var _ = Describe("Session", func() {
|
|||
Expect(sess.GetVersion()).To(Equal(protocol.VersionNumber(4242)))
|
||||
})
|
||||
|
||||
It("accepts new streams", func() {
|
||||
mstr := NewMockStreamI(mockCtrl)
|
||||
streamManager.EXPECT().AcceptStream().Return(mstr, nil)
|
||||
str, err := sess.AcceptStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).To(Equal(mstr))
|
||||
})
|
||||
|
||||
Context("closing", func() {
|
||||
var (
|
||||
runErr error
|
||||
|
@ -1454,17 +1446,21 @@ var _ = Describe("Session", func() {
|
|||
})
|
||||
|
||||
It("accepts streams", func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
mstr := NewMockStreamI(mockCtrl)
|
||||
streamManager.EXPECT().AcceptStream().Return(mstr, nil)
|
||||
str, err := sess.AcceptStream()
|
||||
streamManager.EXPECT().AcceptStream(ctx).Return(mstr, nil)
|
||||
str, err := sess.AcceptStream(ctx)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).To(Equal(mstr))
|
||||
})
|
||||
|
||||
It("accepts unidirectional streams", func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
mstr := NewMockReceiveStreamI(mockCtrl)
|
||||
streamManager.EXPECT().AcceptUniStream().Return(mstr, nil)
|
||||
str, err := sess.AcceptUniStream()
|
||||
streamManager.EXPECT().AcceptUniStream(ctx).Return(mstr, nil)
|
||||
str, err := sess.AcceptUniStream(ctx)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).To(Equal(mstr))
|
||||
})
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
@ -123,13 +124,13 @@ func (m *streamsMap) OpenUniStreamSync() (SendStream, error) {
|
|||
return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
|
||||
}
|
||||
|
||||
func (m *streamsMap) AcceptStream() (Stream, error) {
|
||||
str, err := m.incomingBidiStreams.AcceptStream()
|
||||
func (m *streamsMap) AcceptStream(ctx context.Context) (Stream, error) {
|
||||
str, err := m.incomingBidiStreams.AcceptStream(ctx)
|
||||
return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective.Opposite())
|
||||
}
|
||||
|
||||
func (m *streamsMap) AcceptUniStream() (ReceiveStream, error) {
|
||||
str, err := m.incomingUniStreams.AcceptStream()
|
||||
func (m *streamsMap) AcceptUniStream(ctx context.Context) (ReceiveStream, error) {
|
||||
str, err := m.incomingUniStreams.AcceptStream(ctx)
|
||||
return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective.Opposite())
|
||||
}
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
@ -50,24 +51,28 @@ func newIncomingBidiStreamsMap(
|
|||
}
|
||||
}
|
||||
|
||||
func (m *incomingBidiStreamsMap) AcceptStream() (streamI, error) {
|
||||
func (m *incomingBidiStreamsMap) AcceptStream(ctx context.Context) (streamI, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
var num protocol.StreamNum
|
||||
var str streamI
|
||||
for {
|
||||
num = m.nextStreamToAccept
|
||||
var ok bool
|
||||
if m.closeErr != nil {
|
||||
m.mutex.Unlock()
|
||||
return nil, m.closeErr
|
||||
}
|
||||
var ok bool
|
||||
str, ok = m.streams[num]
|
||||
if ok {
|
||||
break
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
<-m.newStreamChan
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-m.newStreamChan:
|
||||
}
|
||||
m.mutex.Lock()
|
||||
}
|
||||
m.nextStreamToAccept++
|
||||
|
@ -75,9 +80,11 @@ func (m *incomingBidiStreamsMap) AcceptStream() (streamI, error) {
|
|||
if _, ok := m.streamsToDelete[num]; ok {
|
||||
delete(m.streamsToDelete, num)
|
||||
if err := m.deleteStream(num); err != nil {
|
||||
m.mutex.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
return str, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
@ -48,24 +49,28 @@ func newIncomingItemsMap(
|
|||
}
|
||||
}
|
||||
|
||||
func (m *incomingItemsMap) AcceptStream() (item, error) {
|
||||
func (m *incomingItemsMap) AcceptStream(ctx context.Context) (item, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
var num protocol.StreamNum
|
||||
var str item
|
||||
for {
|
||||
num = m.nextStreamToAccept
|
||||
var ok bool
|
||||
if m.closeErr != nil {
|
||||
m.mutex.Unlock()
|
||||
return nil, m.closeErr
|
||||
}
|
||||
var ok bool
|
||||
str, ok = m.streams[num]
|
||||
if ok {
|
||||
break
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
<-m.newStreamChan
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-m.newStreamChan:
|
||||
}
|
||||
m.mutex.Lock()
|
||||
}
|
||||
m.nextStreamToAccept++
|
||||
|
@ -73,9 +78,11 @@ func (m *incomingItemsMap) AcceptStream() (item, error) {
|
|||
if _, ok := m.streamsToDelete[num]; ok {
|
||||
delete(m.streamsToDelete, num)
|
||||
if err := m.deleteStream(num); err != nil {
|
||||
m.mutex.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
return str, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
|
@ -66,10 +67,10 @@ var _ = Describe("Streams Map (incoming)", func() {
|
|||
It("accepts streams in the right order", func() {
|
||||
_, err := m.GetOrOpenStream(2) // open streams 1 and 2
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := m.AcceptStream()
|
||||
str, err := m.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
|
||||
str, err = m.AcceptStream()
|
||||
str, err = m.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2)))
|
||||
})
|
||||
|
@ -90,7 +91,7 @@ var _ = Describe("Streams Map (incoming)", func() {
|
|||
strChan := make(chan item)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
str, err := m.AcceptStream()
|
||||
str, err := m.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
strChan <- str
|
||||
}()
|
||||
|
@ -103,12 +104,26 @@ var _ = Describe("Streams Map (incoming)", func() {
|
|||
Expect(acceptedStr.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
|
||||
})
|
||||
|
||||
It("unblocks AcceptStream when the context is canceled", func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := m.AcceptStream(ctx)
|
||||
Expect(err).To(MatchError("context canceled"))
|
||||
close(done)
|
||||
}()
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
cancel()
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("unblocks AcceptStream when it is closed", func() {
|
||||
testErr := errors.New("test error")
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := m.AcceptStream()
|
||||
_, err := m.AcceptStream(context.Background())
|
||||
Expect(err).To(MatchError(testErr))
|
||||
close(done)
|
||||
}()
|
||||
|
@ -120,7 +135,7 @@ var _ = Describe("Streams Map (incoming)", func() {
|
|||
It("errors AcceptStream immediately if it is closed", func() {
|
||||
testErr := errors.New("test error")
|
||||
m.CloseWithError(testErr)
|
||||
_, err := m.AcceptStream()
|
||||
_, err := m.AcceptStream(context.Background())
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
|
||||
|
@ -141,7 +156,7 @@ var _ = Describe("Streams Map (incoming)", func() {
|
|||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
_, err := m.GetOrOpenStream(1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := m.AcceptStream()
|
||||
str, err := m.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
|
||||
Expect(m.DeleteStream(1)).To(Succeed())
|
||||
|
@ -154,12 +169,12 @@ var _ = Describe("Streams Map (incoming)", func() {
|
|||
_, err := m.GetOrOpenStream(2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m.DeleteStream(2)).To(Succeed())
|
||||
str, err := m.AcceptStream()
|
||||
str, err := m.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
|
||||
// when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
str, err = m.AcceptStream()
|
||||
str, err = m.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2)))
|
||||
})
|
||||
|
@ -174,7 +189,7 @@ var _ = Describe("Streams Map (incoming)", func() {
|
|||
Expect(str).To(BeNil())
|
||||
// when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
str, err = m.AcceptStream()
|
||||
str, err = m.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).ToNot(BeNil())
|
||||
})
|
||||
|
@ -191,7 +206,7 @@ var _ = Describe("Streams Map (incoming)", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
// accept all streams
|
||||
for i := 0; i < 5; i++ {
|
||||
_, err := m.AcceptStream()
|
||||
_, err := m.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
@ -50,24 +51,28 @@ func newIncomingUniStreamsMap(
|
|||
}
|
||||
}
|
||||
|
||||
func (m *incomingUniStreamsMap) AcceptStream() (receiveStreamI, error) {
|
||||
func (m *incomingUniStreamsMap) AcceptStream(ctx context.Context) (receiveStreamI, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
var num protocol.StreamNum
|
||||
var str receiveStreamI
|
||||
for {
|
||||
num = m.nextStreamToAccept
|
||||
var ok bool
|
||||
if m.closeErr != nil {
|
||||
m.mutex.Unlock()
|
||||
return nil, m.closeErr
|
||||
}
|
||||
var ok bool
|
||||
str, ok = m.streams[num]
|
||||
if ok {
|
||||
break
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
<-m.newStreamChan
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-m.newStreamChan:
|
||||
}
|
||||
m.mutex.Lock()
|
||||
}
|
||||
m.nextStreamToAccept++
|
||||
|
@ -75,9 +80,11 @@ func (m *incomingUniStreamsMap) AcceptStream() (receiveStreamI, error) {
|
|||
if _, ok := m.streamsToDelete[num]; ok {
|
||||
delete(m.streamsToDelete, num)
|
||||
if err := m.deleteStream(num); err != nil {
|
||||
m.mutex.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
return str, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
@ -121,7 +122,7 @@ var _ = Describe("Streams Map", func() {
|
|||
It("accepts bidirectional streams", func() {
|
||||
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := m.AcceptStream()
|
||||
str, err := m.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).To(BeAssignableToTypeOf(&stream{}))
|
||||
Expect(str.StreamID()).To(Equal(ids.firstIncomingBidiStream))
|
||||
|
@ -130,7 +131,7 @@ var _ = Describe("Streams Map", func() {
|
|||
It("accepts unidirectional streams", func() {
|
||||
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := m.AcceptUniStream()
|
||||
str, err := m.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).To(BeAssignableToTypeOf(&receiveStream{}))
|
||||
Expect(str.StreamID()).To(Equal(ids.firstIncomingUniStream))
|
||||
|
@ -170,7 +171,7 @@ var _ = Describe("Streams Map", func() {
|
|||
_, err := m.GetOrOpenReceiveStream(id)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m.DeleteStream(id)).To(Succeed())
|
||||
str, err := m.AcceptStream()
|
||||
str, err := m.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).ToNot(BeNil())
|
||||
Expect(str.StreamID()).To(Equal(id))
|
||||
|
@ -203,7 +204,7 @@ var _ = Describe("Streams Map", func() {
|
|||
_, err := m.GetOrOpenReceiveStream(id)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m.DeleteStream(id)).To(Succeed())
|
||||
str, err := m.AcceptUniStream()
|
||||
str, err := m.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).ToNot(BeNil())
|
||||
Expect(str.StreamID()).To(Equal(id))
|
||||
|
@ -393,7 +394,7 @@ var _ = Describe("Streams Map", func() {
|
|||
It("sends a MAX_STREAMS frame for bidirectional streams", func() {
|
||||
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = m.AcceptStream()
|
||||
_, err = m.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{
|
||||
Type: protocol.StreamTypeBidi,
|
||||
|
@ -405,7 +406,7 @@ var _ = Describe("Streams Map", func() {
|
|||
It("sends a MAX_STREAMS frame for unidirectional streams", func() {
|
||||
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = m.AcceptUniStream()
|
||||
_, err = m.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{
|
||||
Type: protocol.StreamTypeUni,
|
||||
|
@ -424,10 +425,10 @@ var _ = Describe("Streams Map", func() {
|
|||
_, err = m.OpenUniStream()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(Equal(testErr.Error()))
|
||||
_, err = m.AcceptStream()
|
||||
_, err = m.AcceptStream(context.Background())
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(Equal(testErr.Error()))
|
||||
_, err = m.AcceptUniStream()
|
||||
_, err = m.AcceptUniStream(context.Background())
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(Equal(testErr.Error()))
|
||||
})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue