mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-05 05:07:36 +03:00
add a context to Session.Accept{Uni}Stream
This commit is contained in:
parent
f74082b2fb
commit
5550ba2c3b
28 changed files with 140 additions and 105 deletions
|
@ -8,7 +8,7 @@
|
||||||
- Enforce application protocol negotiation (via `tls.Config.NextProtos`).
|
- Enforce application protocol negotiation (via `tls.Config.NextProtos`).
|
||||||
- Use a varint for error codes.
|
- Use a varint for error codes.
|
||||||
- Add support for [quic-trace](https://github.com/google/quic-trace).
|
- Add support for [quic-trace](https://github.com/google/quic-trace).
|
||||||
- Add a context to `Listener.Accept`.
|
- Add a context to `Listener.Accept` and `Session.Accept{Uni}Stream`.
|
||||||
|
|
||||||
## v0.11.0 (2019-04-05)
|
## v0.11.0 (2019-04-05)
|
||||||
|
|
||||||
|
|
|
@ -67,7 +67,7 @@ func init() {
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
close(handshakeChan)
|
close(handshakeChan)
|
||||||
str, err := sess.AcceptStream()
|
str, err := sess.AcceptStream(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
|
|
|
@ -40,7 +40,7 @@ func echoServer() error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
stream, err := sess.AcceptStream()
|
stream, err := sess.AcceptStream(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -138,7 +138,7 @@ func (s *Server) handleConn(sess quic.Session) {
|
||||||
str.Write(buf.Bytes())
|
str.Write(buf.Bytes())
|
||||||
|
|
||||||
for {
|
for {
|
||||||
str, err := sess.AcceptStream()
|
str, err := sess.AcceptStream(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Debugf("Accepting stream failed: %s", err)
|
s.logger.Debugf("Accepting stream failed: %s", err)
|
||||||
return
|
return
|
||||||
|
|
|
@ -76,7 +76,7 @@ var _ = Describe("Stream Cancelations", func() {
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
str, err := sess.AcceptUniStream()
|
str, err := sess.AcceptUniStream(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
// cancel around 2/3 of the streams
|
// cancel around 2/3 of the streams
|
||||||
if rand.Int31()%3 != 0 {
|
if rand.Int31()%3 != 0 {
|
||||||
|
@ -120,7 +120,7 @@ var _ = Describe("Stream Cancelations", func() {
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
str, err := sess.AcceptUniStream()
|
str, err := sess.AcceptUniStream(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
// only read some data from about 1/3 of the streams
|
// only read some data from about 1/3 of the streams
|
||||||
if rand.Int31()%3 != 0 {
|
if rand.Int31()%3 != 0 {
|
||||||
|
@ -168,7 +168,7 @@ var _ = Describe("Stream Cancelations", func() {
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
str, err := sess.AcceptUniStream()
|
str, err := sess.AcceptUniStream(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
data, err := ioutil.ReadAll(str)
|
data, err := ioutil.ReadAll(str)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -304,7 +304,7 @@ var _ = Describe("Stream Cancelations", func() {
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
str, err := sess.AcceptUniStream()
|
str, err := sess.AcceptUniStream(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
// cancel around half of the streams
|
// cancel around half of the streams
|
||||||
if rand.Int31()%2 == 0 {
|
if rand.Int31()%2 == 0 {
|
||||||
|
@ -383,7 +383,7 @@ var _ = Describe("Stream Cancelations", func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
|
||||||
str, err := sess.AcceptUniStream()
|
str, err := sess.AcceptUniStream(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
r := io.Reader(str)
|
r := io.Reader(str)
|
||||||
|
|
|
@ -52,7 +52,7 @@ var _ = Describe("Connection ID lengths tests", func() {
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
defer cl.Close()
|
defer cl.Close()
|
||||||
str, err := cl.AcceptStream()
|
str, err := cl.AcceptStream(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
data, err := ioutil.ReadAll(str)
|
data, err := ioutil.ReadAll(str)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
|
@ -30,7 +30,7 @@ var _ = Describe("Stream deadline tests", func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
sess, err := server.Accept(context.Background())
|
sess, err := server.Accept(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
serverStr, err = sess.AcceptStream()
|
serverStr, err = sess.AcceptStream(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
_, err = serverStr.Read([]byte{0})
|
_, err = serverStr.Read([]byte{0})
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
|
@ -110,7 +110,7 @@ var _ = Describe("Drop Tests", func() {
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
defer sess.Close()
|
defer sess.Close()
|
||||||
str, err := sess.AcceptStream()
|
str, err := sess.AcceptStream(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
for i := uint8(1); i <= numMessages; i++ {
|
for i := uint8(1); i <= numMessages; i++ {
|
||||||
b := []byte{0}
|
b := []byte{0}
|
||||||
|
|
|
@ -61,7 +61,7 @@ var _ = Describe("Handshake drop tests", func() {
|
||||||
sess, err := ln.Accept(context.Background())
|
sess, err := ln.Accept(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
defer sess.Close()
|
defer sess.Close()
|
||||||
str, err := sess.AcceptStream()
|
str, err := sess.AcceptStream(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
b := make([]byte, 6)
|
b := make([]byte, 6)
|
||||||
_, err = gbytes.TimeoutReader(str, 10*time.Second).Read(b)
|
_, err = gbytes.TimeoutReader(str, 10*time.Second).Read(b)
|
||||||
|
@ -107,7 +107,7 @@ var _ = Describe("Handshake drop tests", func() {
|
||||||
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
str, err := sess.AcceptStream()
|
str, err := sess.AcceptStream(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
b := make([]byte, 6)
|
b := make([]byte, 6)
|
||||||
_, err = gbytes.TimeoutReader(str, 10*time.Second).Read(b)
|
_, err = gbytes.TimeoutReader(str, 10*time.Second).Read(b)
|
||||||
|
|
|
@ -159,7 +159,7 @@ var _ = Describe("Handshake tests", func() {
|
||||||
errChan := make(chan error)
|
errChan := make(chan error)
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
_, err := sess.AcceptStream()
|
_, err := sess.AcceptStream(context.Background())
|
||||||
errChan <- err
|
errChan <- err
|
||||||
}()
|
}()
|
||||||
Eventually(errChan).Should(Receive(&err))
|
Eventually(errChan).Should(Receive(&err))
|
||||||
|
|
|
@ -51,7 +51,7 @@ var _ = Describe("Multiplexing", func() {
|
||||||
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
str, err := sess.AcceptStream()
|
str, err := sess.AcceptStream(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
data, err := ioutil.ReadAll(str)
|
data, err := ioutil.ReadAll(str)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
|
@ -68,7 +68,7 @@ var _ = Describe("non-zero RTT", func() {
|
||||||
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
str, err := sess.AcceptStream()
|
str, err := sess.AcceptStream(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
data, err := ioutil.ReadAll(str)
|
data, err := ioutil.ReadAll(str)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
|
@ -64,7 +64,7 @@ var _ = Describe("Stateless Resets", func() {
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
str, err := sess.AcceptStream()
|
str, err := sess.AcceptStream(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
data := make([]byte, 6)
|
data := make([]byte, 6)
|
||||||
_, err = str.Read(data)
|
_, err = str.Read(data)
|
||||||
|
|
|
@ -71,7 +71,7 @@ var _ = Describe("Bidirectional streams", func() {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(numStreams)
|
wg.Add(numStreams)
|
||||||
for i := 0; i < numStreams; i++ {
|
for i := 0; i < numStreams; i++ {
|
||||||
str, err := sess.AcceptStream()
|
str, err := sess.AcceptStream(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
|
|
|
@ -95,7 +95,7 @@ var _ = Describe("Timeout tests", func() {
|
||||||
&quic.Config{IdleTimeout: idleTimeout},
|
&quic.Config{IdleTimeout: idleTimeout},
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
strIn, err := sess.AcceptStream()
|
strIn, err := sess.AcceptStream(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
strOut, err := sess.OpenStream()
|
strOut, err := sess.OpenStream()
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
@ -116,9 +116,9 @@ var _ = Describe("Timeout tests", func() {
|
||||||
checkTimeoutError(err)
|
checkTimeoutError(err)
|
||||||
_, err = sess.OpenUniStream()
|
_, err = sess.OpenUniStream()
|
||||||
checkTimeoutError(err)
|
checkTimeoutError(err)
|
||||||
_, err = sess.AcceptStream()
|
_, err = sess.AcceptStream(context.Background())
|
||||||
checkTimeoutError(err)
|
checkTimeoutError(err)
|
||||||
_, err = sess.AcceptUniStream()
|
_, err = sess.AcceptUniStream(context.Background())
|
||||||
checkTimeoutError(err)
|
checkTimeoutError(err)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -148,7 +148,7 @@ var _ = Describe("Timeout tests", func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
sess, err := server.Accept(context.Background())
|
sess, err := server.Accept(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
sess.AcceptStream() // blocks until the session is closed
|
sess.AcceptStream(context.Background()) // blocks until the session is closed
|
||||||
close(serverSessionClosed)
|
close(serverSessionClosed)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -162,7 +162,7 @@ var _ = Describe("Timeout tests", func() {
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
_, err := sess.AcceptStream()
|
_, err := sess.AcceptStream(context.Background())
|
||||||
checkTimeoutError(err)
|
checkTimeoutError(err)
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
@ -189,7 +189,7 @@ var _ = Describe("Timeout tests", func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
sess, err := server.Accept(context.Background())
|
sess, err := server.Accept(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
sess.AcceptStream() // blocks until the session is closed
|
sess.AcceptStream(context.Background()) // blocks until the session is closed
|
||||||
close(serverSessionClosed)
|
close(serverSessionClosed)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -212,7 +212,7 @@ var _ = Describe("Timeout tests", func() {
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
_, err := sess.AcceptStream()
|
_, err := sess.AcceptStream(context.Background())
|
||||||
checkTimeoutError(err)
|
checkTimeoutError(err)
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
|
@ -57,7 +57,7 @@ var _ = Describe("Unidirectional Streams", func() {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(numStreams)
|
wg.Add(numStreams)
|
||||||
for i := 0; i < numStreams; i++ {
|
for i := 0; i < numStreams; i++ {
|
||||||
str, err := sess.AcceptUniStream()
|
str, err := sess.AcceptUniStream(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
|
|
|
@ -127,11 +127,11 @@ type Session interface {
|
||||||
// AcceptStream returns the next stream opened by the peer, blocking until one is available.
|
// AcceptStream returns the next stream opened by the peer, blocking until one is available.
|
||||||
// If the session was closed due to a timeout, the error satisfies
|
// If the session was closed due to a timeout, the error satisfies
|
||||||
// the net.Error interface, and Timeout() will be true.
|
// the net.Error interface, and Timeout() will be true.
|
||||||
AcceptStream() (Stream, error)
|
AcceptStream(context.Context) (Stream, error)
|
||||||
// AcceptUniStream returns the next unidirectional stream opened by the peer, blocking until one is available.
|
// AcceptUniStream returns the next unidirectional stream opened by the peer, blocking until one is available.
|
||||||
// If the session was closed due to a timeout, the error satisfies
|
// If the session was closed due to a timeout, the error satisfies
|
||||||
// the net.Error interface, and Timeout() will be true.
|
// the net.Error interface, and Timeout() will be true.
|
||||||
AcceptUniStream() (ReceiveStream, error)
|
AcceptUniStream(context.Context) (ReceiveStream, error)
|
||||||
// OpenStream opens a new bidirectional QUIC stream.
|
// OpenStream opens a new bidirectional QUIC stream.
|
||||||
// There is no signaling to the peer about new streams:
|
// There is no signaling to the peer about new streams:
|
||||||
// The peer can only accept the stream after data has been sent on the stream.
|
// The peer can only accept the stream after data has been sent on the stream.
|
||||||
|
|
|
@ -39,33 +39,33 @@ func (m *MockSession) EXPECT() *MockSessionMockRecorder {
|
||||||
}
|
}
|
||||||
|
|
||||||
// AcceptStream mocks base method
|
// AcceptStream mocks base method
|
||||||
func (m *MockSession) AcceptStream() (quic_go.Stream, error) {
|
func (m *MockSession) AcceptStream(arg0 context.Context) (quic_go.Stream, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "AcceptStream")
|
ret := m.ctrl.Call(m, "AcceptStream", arg0)
|
||||||
ret0, _ := ret[0].(quic_go.Stream)
|
ret0, _ := ret[0].(quic_go.Stream)
|
||||||
ret1, _ := ret[1].(error)
|
ret1, _ := ret[1].(error)
|
||||||
return ret0, ret1
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// AcceptStream indicates an expected call of AcceptStream
|
// AcceptStream indicates an expected call of AcceptStream
|
||||||
func (mr *MockSessionMockRecorder) AcceptStream() *gomock.Call {
|
func (mr *MockSessionMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockSession)(nil).AcceptStream))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockSession)(nil).AcceptStream), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AcceptUniStream mocks base method
|
// AcceptUniStream mocks base method
|
||||||
func (m *MockSession) AcceptUniStream() (quic_go.ReceiveStream, error) {
|
func (m *MockSession) AcceptUniStream(arg0 context.Context) (quic_go.ReceiveStream, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "AcceptUniStream")
|
ret := m.ctrl.Call(m, "AcceptUniStream", arg0)
|
||||||
ret0, _ := ret[0].(quic_go.ReceiveStream)
|
ret0, _ := ret[0].(quic_go.ReceiveStream)
|
||||||
ret1, _ := ret[1].(error)
|
ret1, _ := ret[1].(error)
|
||||||
return ret0, ret1
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// AcceptUniStream indicates an expected call of AcceptUniStream
|
// AcceptUniStream indicates an expected call of AcceptUniStream
|
||||||
func (mr *MockSessionMockRecorder) AcceptUniStream() *gomock.Call {
|
func (mr *MockSessionMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockSession)(nil).AcceptUniStream))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockSession)(nil).AcceptUniStream), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close mocks base method
|
// Close mocks base method
|
||||||
|
|
|
@ -38,33 +38,33 @@ func (m *MockQuicSession) EXPECT() *MockQuicSessionMockRecorder {
|
||||||
}
|
}
|
||||||
|
|
||||||
// AcceptStream mocks base method
|
// AcceptStream mocks base method
|
||||||
func (m *MockQuicSession) AcceptStream() (Stream, error) {
|
func (m *MockQuicSession) AcceptStream(arg0 context.Context) (Stream, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "AcceptStream")
|
ret := m.ctrl.Call(m, "AcceptStream", arg0)
|
||||||
ret0, _ := ret[0].(Stream)
|
ret0, _ := ret[0].(Stream)
|
||||||
ret1, _ := ret[1].(error)
|
ret1, _ := ret[1].(error)
|
||||||
return ret0, ret1
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// AcceptStream indicates an expected call of AcceptStream
|
// AcceptStream indicates an expected call of AcceptStream
|
||||||
func (mr *MockQuicSessionMockRecorder) AcceptStream() *gomock.Call {
|
func (mr *MockQuicSessionMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockQuicSession)(nil).AcceptStream))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockQuicSession)(nil).AcceptStream), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AcceptUniStream mocks base method
|
// AcceptUniStream mocks base method
|
||||||
func (m *MockQuicSession) AcceptUniStream() (ReceiveStream, error) {
|
func (m *MockQuicSession) AcceptUniStream(arg0 context.Context) (ReceiveStream, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "AcceptUniStream")
|
ret := m.ctrl.Call(m, "AcceptUniStream", arg0)
|
||||||
ret0, _ := ret[0].(ReceiveStream)
|
ret0, _ := ret[0].(ReceiveStream)
|
||||||
ret1, _ := ret[1].(error)
|
ret1, _ := ret[1].(error)
|
||||||
return ret0, ret1
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// AcceptUniStream indicates an expected call of AcceptUniStream
|
// AcceptUniStream indicates an expected call of AcceptUniStream
|
||||||
func (mr *MockQuicSessionMockRecorder) AcceptUniStream() *gomock.Call {
|
func (mr *MockQuicSessionMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockQuicSession)(nil).AcceptUniStream))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockQuicSession)(nil).AcceptUniStream), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close mocks base method
|
// Close mocks base method
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
package quic
|
package quic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
context "context"
|
||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
@ -37,33 +38,33 @@ func (m *MockStreamManager) EXPECT() *MockStreamManagerMockRecorder {
|
||||||
}
|
}
|
||||||
|
|
||||||
// AcceptStream mocks base method
|
// AcceptStream mocks base method
|
||||||
func (m *MockStreamManager) AcceptStream() (Stream, error) {
|
func (m *MockStreamManager) AcceptStream(arg0 context.Context) (Stream, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "AcceptStream")
|
ret := m.ctrl.Call(m, "AcceptStream", arg0)
|
||||||
ret0, _ := ret[0].(Stream)
|
ret0, _ := ret[0].(Stream)
|
||||||
ret1, _ := ret[1].(error)
|
ret1, _ := ret[1].(error)
|
||||||
return ret0, ret1
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// AcceptStream indicates an expected call of AcceptStream
|
// AcceptStream indicates an expected call of AcceptStream
|
||||||
func (mr *MockStreamManagerMockRecorder) AcceptStream() *gomock.Call {
|
func (mr *MockStreamManagerMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptStream))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptStream), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AcceptUniStream mocks base method
|
// AcceptUniStream mocks base method
|
||||||
func (m *MockStreamManager) AcceptUniStream() (ReceiveStream, error) {
|
func (m *MockStreamManager) AcceptUniStream(arg0 context.Context) (ReceiveStream, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "AcceptUniStream")
|
ret := m.ctrl.Call(m, "AcceptUniStream", arg0)
|
||||||
ret0, _ := ret[0].(ReceiveStream)
|
ret0, _ := ret[0].(ReceiveStream)
|
||||||
ret1, _ := ret[1].(error)
|
ret1, _ := ret[1].(error)
|
||||||
return ret0, ret1
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// AcceptUniStream indicates an expected call of AcceptUniStream
|
// AcceptUniStream indicates an expected call of AcceptUniStream
|
||||||
func (mr *MockStreamManagerMockRecorder) AcceptUniStream() *gomock.Call {
|
func (mr *MockStreamManagerMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptUniStream))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptUniStream), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseWithError mocks base method
|
// CloseWithError mocks base method
|
||||||
|
|
12
session.go
12
session.go
|
@ -39,8 +39,8 @@ type streamManager interface {
|
||||||
OpenUniStream() (SendStream, error)
|
OpenUniStream() (SendStream, error)
|
||||||
OpenStreamSync() (Stream, error)
|
OpenStreamSync() (Stream, error)
|
||||||
OpenUniStreamSync() (SendStream, error)
|
OpenUniStreamSync() (SendStream, error)
|
||||||
AcceptStream() (Stream, error)
|
AcceptStream(context.Context) (Stream, error)
|
||||||
AcceptUniStream() (ReceiveStream, error)
|
AcceptUniStream(context.Context) (ReceiveStream, error)
|
||||||
DeleteStream(protocol.StreamID) error
|
DeleteStream(protocol.StreamID) error
|
||||||
UpdateLimits(*handshake.TransportParameters) error
|
UpdateLimits(*handshake.TransportParameters) error
|
||||||
HandleMaxStreamsFrame(*wire.MaxStreamsFrame) error
|
HandleMaxStreamsFrame(*wire.MaxStreamsFrame) error
|
||||||
|
@ -1233,12 +1233,12 @@ func (s *session) logPacket(packet *packedPacket) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// AcceptStream returns the next stream openend by the peer
|
// AcceptStream returns the next stream openend by the peer
|
||||||
func (s *session) AcceptStream() (Stream, error) {
|
func (s *session) AcceptStream(ctx context.Context) (Stream, error) {
|
||||||
return s.streamsMap.AcceptStream()
|
return s.streamsMap.AcceptStream(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) AcceptUniStream() (ReceiveStream, error) {
|
func (s *session) AcceptUniStream(ctx context.Context) (ReceiveStream, error) {
|
||||||
return s.streamsMap.AcceptUniStream()
|
return s.streamsMap.AcceptUniStream(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OpenStream opens a stream
|
// OpenStream opens a stream
|
||||||
|
|
|
@ -367,14 +367,6 @@ var _ = Describe("Session", func() {
|
||||||
Expect(sess.GetVersion()).To(Equal(protocol.VersionNumber(4242)))
|
Expect(sess.GetVersion()).To(Equal(protocol.VersionNumber(4242)))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("accepts new streams", func() {
|
|
||||||
mstr := NewMockStreamI(mockCtrl)
|
|
||||||
streamManager.EXPECT().AcceptStream().Return(mstr, nil)
|
|
||||||
str, err := sess.AcceptStream()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(str).To(Equal(mstr))
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("closing", func() {
|
Context("closing", func() {
|
||||||
var (
|
var (
|
||||||
runErr error
|
runErr error
|
||||||
|
@ -1454,17 +1446,21 @@ var _ = Describe("Session", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("accepts streams", func() {
|
It("accepts streams", func() {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||||
|
defer cancel()
|
||||||
mstr := NewMockStreamI(mockCtrl)
|
mstr := NewMockStreamI(mockCtrl)
|
||||||
streamManager.EXPECT().AcceptStream().Return(mstr, nil)
|
streamManager.EXPECT().AcceptStream(ctx).Return(mstr, nil)
|
||||||
str, err := sess.AcceptStream()
|
str, err := sess.AcceptStream(ctx)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(str).To(Equal(mstr))
|
Expect(str).To(Equal(mstr))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("accepts unidirectional streams", func() {
|
It("accepts unidirectional streams", func() {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
|
defer cancel()
|
||||||
mstr := NewMockReceiveStreamI(mockCtrl)
|
mstr := NewMockReceiveStreamI(mockCtrl)
|
||||||
streamManager.EXPECT().AcceptUniStream().Return(mstr, nil)
|
streamManager.EXPECT().AcceptUniStream(ctx).Return(mstr, nil)
|
||||||
str, err := sess.AcceptUniStream()
|
str, err := sess.AcceptUniStream(ctx)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(str).To(Equal(mstr))
|
Expect(str).To(Equal(mstr))
|
||||||
})
|
})
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package quic
|
package quic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
@ -123,13 +124,13 @@ func (m *streamsMap) OpenUniStreamSync() (SendStream, error) {
|
||||||
return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
|
return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *streamsMap) AcceptStream() (Stream, error) {
|
func (m *streamsMap) AcceptStream(ctx context.Context) (Stream, error) {
|
||||||
str, err := m.incomingBidiStreams.AcceptStream()
|
str, err := m.incomingBidiStreams.AcceptStream(ctx)
|
||||||
return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective.Opposite())
|
return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective.Opposite())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *streamsMap) AcceptUniStream() (ReceiveStream, error) {
|
func (m *streamsMap) AcceptUniStream(ctx context.Context) (ReceiveStream, error) {
|
||||||
str, err := m.incomingUniStreams.AcceptStream()
|
str, err := m.incomingUniStreams.AcceptStream(ctx)
|
||||||
return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective.Opposite())
|
return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective.Opposite())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
package quic
|
package quic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
@ -50,24 +51,28 @@ func newIncomingBidiStreamsMap(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *incomingBidiStreamsMap) AcceptStream() (streamI, error) {
|
func (m *incomingBidiStreamsMap) AcceptStream(ctx context.Context) (streamI, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
var num protocol.StreamNum
|
var num protocol.StreamNum
|
||||||
var str streamI
|
var str streamI
|
||||||
for {
|
for {
|
||||||
num = m.nextStreamToAccept
|
num = m.nextStreamToAccept
|
||||||
var ok bool
|
|
||||||
if m.closeErr != nil {
|
if m.closeErr != nil {
|
||||||
|
m.mutex.Unlock()
|
||||||
return nil, m.closeErr
|
return nil, m.closeErr
|
||||||
}
|
}
|
||||||
|
var ok bool
|
||||||
str, ok = m.streams[num]
|
str, ok = m.streams[num]
|
||||||
if ok {
|
if ok {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
<-m.newStreamChan
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-m.newStreamChan:
|
||||||
|
}
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
}
|
}
|
||||||
m.nextStreamToAccept++
|
m.nextStreamToAccept++
|
||||||
|
@ -75,9 +80,11 @@ func (m *incomingBidiStreamsMap) AcceptStream() (streamI, error) {
|
||||||
if _, ok := m.streamsToDelete[num]; ok {
|
if _, ok := m.streamsToDelete[num]; ok {
|
||||||
delete(m.streamsToDelete, num)
|
delete(m.streamsToDelete, num)
|
||||||
if err := m.deleteStream(num); err != nil {
|
if err := m.deleteStream(num); err != nil {
|
||||||
|
m.mutex.Unlock()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
return str, nil
|
return str, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package quic
|
package quic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
@ -48,24 +49,28 @@ func newIncomingItemsMap(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *incomingItemsMap) AcceptStream() (item, error) {
|
func (m *incomingItemsMap) AcceptStream(ctx context.Context) (item, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
var num protocol.StreamNum
|
var num protocol.StreamNum
|
||||||
var str item
|
var str item
|
||||||
for {
|
for {
|
||||||
num = m.nextStreamToAccept
|
num = m.nextStreamToAccept
|
||||||
var ok bool
|
|
||||||
if m.closeErr != nil {
|
if m.closeErr != nil {
|
||||||
|
m.mutex.Unlock()
|
||||||
return nil, m.closeErr
|
return nil, m.closeErr
|
||||||
}
|
}
|
||||||
|
var ok bool
|
||||||
str, ok = m.streams[num]
|
str, ok = m.streams[num]
|
||||||
if ok {
|
if ok {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
<-m.newStreamChan
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-m.newStreamChan:
|
||||||
|
}
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
}
|
}
|
||||||
m.nextStreamToAccept++
|
m.nextStreamToAccept++
|
||||||
|
@ -73,9 +78,11 @@ func (m *incomingItemsMap) AcceptStream() (item, error) {
|
||||||
if _, ok := m.streamsToDelete[num]; ok {
|
if _, ok := m.streamsToDelete[num]; ok {
|
||||||
delete(m.streamsToDelete, num)
|
delete(m.streamsToDelete, num)
|
||||||
if err := m.deleteStream(num); err != nil {
|
if err := m.deleteStream(num); err != nil {
|
||||||
|
m.mutex.Unlock()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
return str, nil
|
return str, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package quic
|
package quic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
|
@ -66,10 +67,10 @@ var _ = Describe("Streams Map (incoming)", func() {
|
||||||
It("accepts streams in the right order", func() {
|
It("accepts streams in the right order", func() {
|
||||||
_, err := m.GetOrOpenStream(2) // open streams 1 and 2
|
_, err := m.GetOrOpenStream(2) // open streams 1 and 2
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
str, err := m.AcceptStream()
|
str, err := m.AcceptStream(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
|
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
|
||||||
str, err = m.AcceptStream()
|
str, err = m.AcceptStream(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2)))
|
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2)))
|
||||||
})
|
})
|
||||||
|
@ -90,7 +91,7 @@ var _ = Describe("Streams Map (incoming)", func() {
|
||||||
strChan := make(chan item)
|
strChan := make(chan item)
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
str, err := m.AcceptStream()
|
str, err := m.AcceptStream(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
strChan <- str
|
strChan <- str
|
||||||
}()
|
}()
|
||||||
|
@ -103,12 +104,26 @@ var _ = Describe("Streams Map (incoming)", func() {
|
||||||
Expect(acceptedStr.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
|
Expect(acceptedStr.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("unblocks AcceptStream when the context is canceled", func() {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
_, err := m.AcceptStream(ctx)
|
||||||
|
Expect(err).To(MatchError("context canceled"))
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
Consistently(done).ShouldNot(BeClosed())
|
||||||
|
cancel()
|
||||||
|
Eventually(done).Should(BeClosed())
|
||||||
|
})
|
||||||
|
|
||||||
It("unblocks AcceptStream when it is closed", func() {
|
It("unblocks AcceptStream when it is closed", func() {
|
||||||
testErr := errors.New("test error")
|
testErr := errors.New("test error")
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
_, err := m.AcceptStream()
|
_, err := m.AcceptStream(context.Background())
|
||||||
Expect(err).To(MatchError(testErr))
|
Expect(err).To(MatchError(testErr))
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
@ -120,7 +135,7 @@ var _ = Describe("Streams Map (incoming)", func() {
|
||||||
It("errors AcceptStream immediately if it is closed", func() {
|
It("errors AcceptStream immediately if it is closed", func() {
|
||||||
testErr := errors.New("test error")
|
testErr := errors.New("test error")
|
||||||
m.CloseWithError(testErr)
|
m.CloseWithError(testErr)
|
||||||
_, err := m.AcceptStream()
|
_, err := m.AcceptStream(context.Background())
|
||||||
Expect(err).To(MatchError(testErr))
|
Expect(err).To(MatchError(testErr))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -141,7 +156,7 @@ var _ = Describe("Streams Map (incoming)", func() {
|
||||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||||
_, err := m.GetOrOpenStream(1)
|
_, err := m.GetOrOpenStream(1)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
str, err := m.AcceptStream()
|
str, err := m.AcceptStream(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
|
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
|
||||||
Expect(m.DeleteStream(1)).To(Succeed())
|
Expect(m.DeleteStream(1)).To(Succeed())
|
||||||
|
@ -154,12 +169,12 @@ var _ = Describe("Streams Map (incoming)", func() {
|
||||||
_, err := m.GetOrOpenStream(2)
|
_, err := m.GetOrOpenStream(2)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(m.DeleteStream(2)).To(Succeed())
|
Expect(m.DeleteStream(2)).To(Succeed())
|
||||||
str, err := m.AcceptStream()
|
str, err := m.AcceptStream(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
|
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
|
||||||
// when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued
|
// when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued
|
||||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||||
str, err = m.AcceptStream()
|
str, err = m.AcceptStream(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2)))
|
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2)))
|
||||||
})
|
})
|
||||||
|
@ -174,7 +189,7 @@ var _ = Describe("Streams Map (incoming)", func() {
|
||||||
Expect(str).To(BeNil())
|
Expect(str).To(BeNil())
|
||||||
// when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued
|
// when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued
|
||||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||||
str, err = m.AcceptStream()
|
str, err = m.AcceptStream(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(str).ToNot(BeNil())
|
Expect(str).ToNot(BeNil())
|
||||||
})
|
})
|
||||||
|
@ -191,7 +206,7 @@ var _ = Describe("Streams Map (incoming)", func() {
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
// accept all streams
|
// accept all streams
|
||||||
for i := 0; i < 5; i++ {
|
for i := 0; i < 5; i++ {
|
||||||
_, err := m.AcceptStream()
|
_, err := m.AcceptStream(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
}
|
}
|
||||||
mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
|
mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
package quic
|
package quic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
@ -50,24 +51,28 @@ func newIncomingUniStreamsMap(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *incomingUniStreamsMap) AcceptStream() (receiveStreamI, error) {
|
func (m *incomingUniStreamsMap) AcceptStream(ctx context.Context) (receiveStreamI, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
var num protocol.StreamNum
|
var num protocol.StreamNum
|
||||||
var str receiveStreamI
|
var str receiveStreamI
|
||||||
for {
|
for {
|
||||||
num = m.nextStreamToAccept
|
num = m.nextStreamToAccept
|
||||||
var ok bool
|
|
||||||
if m.closeErr != nil {
|
if m.closeErr != nil {
|
||||||
|
m.mutex.Unlock()
|
||||||
return nil, m.closeErr
|
return nil, m.closeErr
|
||||||
}
|
}
|
||||||
|
var ok bool
|
||||||
str, ok = m.streams[num]
|
str, ok = m.streams[num]
|
||||||
if ok {
|
if ok {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
<-m.newStreamChan
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-m.newStreamChan:
|
||||||
|
}
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
}
|
}
|
||||||
m.nextStreamToAccept++
|
m.nextStreamToAccept++
|
||||||
|
@ -75,9 +80,11 @@ func (m *incomingUniStreamsMap) AcceptStream() (receiveStreamI, error) {
|
||||||
if _, ok := m.streamsToDelete[num]; ok {
|
if _, ok := m.streamsToDelete[num]; ok {
|
||||||
delete(m.streamsToDelete, num)
|
delete(m.streamsToDelete, num)
|
||||||
if err := m.deleteStream(num); err != nil {
|
if err := m.deleteStream(num); err != nil {
|
||||||
|
m.mutex.Unlock()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
return str, nil
|
return str, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package quic
|
package quic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
@ -121,7 +122,7 @@ var _ = Describe("Streams Map", func() {
|
||||||
It("accepts bidirectional streams", func() {
|
It("accepts bidirectional streams", func() {
|
||||||
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream)
|
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
str, err := m.AcceptStream()
|
str, err := m.AcceptStream(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(str).To(BeAssignableToTypeOf(&stream{}))
|
Expect(str).To(BeAssignableToTypeOf(&stream{}))
|
||||||
Expect(str.StreamID()).To(Equal(ids.firstIncomingBidiStream))
|
Expect(str.StreamID()).To(Equal(ids.firstIncomingBidiStream))
|
||||||
|
@ -130,7 +131,7 @@ var _ = Describe("Streams Map", func() {
|
||||||
It("accepts unidirectional streams", func() {
|
It("accepts unidirectional streams", func() {
|
||||||
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream)
|
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
str, err := m.AcceptUniStream()
|
str, err := m.AcceptUniStream(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(str).To(BeAssignableToTypeOf(&receiveStream{}))
|
Expect(str).To(BeAssignableToTypeOf(&receiveStream{}))
|
||||||
Expect(str.StreamID()).To(Equal(ids.firstIncomingUniStream))
|
Expect(str.StreamID()).To(Equal(ids.firstIncomingUniStream))
|
||||||
|
@ -170,7 +171,7 @@ var _ = Describe("Streams Map", func() {
|
||||||
_, err := m.GetOrOpenReceiveStream(id)
|
_, err := m.GetOrOpenReceiveStream(id)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(m.DeleteStream(id)).To(Succeed())
|
Expect(m.DeleteStream(id)).To(Succeed())
|
||||||
str, err := m.AcceptStream()
|
str, err := m.AcceptStream(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(str).ToNot(BeNil())
|
Expect(str).ToNot(BeNil())
|
||||||
Expect(str.StreamID()).To(Equal(id))
|
Expect(str.StreamID()).To(Equal(id))
|
||||||
|
@ -203,7 +204,7 @@ var _ = Describe("Streams Map", func() {
|
||||||
_, err := m.GetOrOpenReceiveStream(id)
|
_, err := m.GetOrOpenReceiveStream(id)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(m.DeleteStream(id)).To(Succeed())
|
Expect(m.DeleteStream(id)).To(Succeed())
|
||||||
str, err := m.AcceptUniStream()
|
str, err := m.AcceptUniStream(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(str).ToNot(BeNil())
|
Expect(str).ToNot(BeNil())
|
||||||
Expect(str.StreamID()).To(Equal(id))
|
Expect(str.StreamID()).To(Equal(id))
|
||||||
|
@ -393,7 +394,7 @@ var _ = Describe("Streams Map", func() {
|
||||||
It("sends a MAX_STREAMS frame for bidirectional streams", func() {
|
It("sends a MAX_STREAMS frame for bidirectional streams", func() {
|
||||||
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream)
|
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
_, err = m.AcceptStream()
|
_, err = m.AcceptStream(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{
|
mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{
|
||||||
Type: protocol.StreamTypeBidi,
|
Type: protocol.StreamTypeBidi,
|
||||||
|
@ -405,7 +406,7 @@ var _ = Describe("Streams Map", func() {
|
||||||
It("sends a MAX_STREAMS frame for unidirectional streams", func() {
|
It("sends a MAX_STREAMS frame for unidirectional streams", func() {
|
||||||
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream)
|
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
_, err = m.AcceptUniStream()
|
_, err = m.AcceptUniStream(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{
|
mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{
|
||||||
Type: protocol.StreamTypeUni,
|
Type: protocol.StreamTypeUni,
|
||||||
|
@ -424,10 +425,10 @@ var _ = Describe("Streams Map", func() {
|
||||||
_, err = m.OpenUniStream()
|
_, err = m.OpenUniStream()
|
||||||
Expect(err).To(HaveOccurred())
|
Expect(err).To(HaveOccurred())
|
||||||
Expect(err.Error()).To(Equal(testErr.Error()))
|
Expect(err.Error()).To(Equal(testErr.Error()))
|
||||||
_, err = m.AcceptStream()
|
_, err = m.AcceptStream(context.Background())
|
||||||
Expect(err).To(HaveOccurred())
|
Expect(err).To(HaveOccurred())
|
||||||
Expect(err.Error()).To(Equal(testErr.Error()))
|
Expect(err.Error()).To(Equal(testErr.Error()))
|
||||||
_, err = m.AcceptUniStream()
|
_, err = m.AcceptUniStream(context.Background())
|
||||||
Expect(err).To(HaveOccurred())
|
Expect(err).To(HaveOccurred())
|
||||||
Expect(err.Error()).To(Equal(testErr.Error()))
|
Expect(err.Error()).To(Equal(testErr.Error()))
|
||||||
})
|
})
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue