diff --git a/Changelog.md b/Changelog.md index f365d1f6..49e9d4ad 100644 --- a/Changelog.md +++ b/Changelog.md @@ -8,7 +8,7 @@ - Enforce application protocol negotiation (via `tls.Config.NextProtos`). - Use a varint for error codes. - Add support for [quic-trace](https://github.com/google/quic-trace). -- Add a context to `Listener.Accept`. +- Add a context to `Listener.Accept` and `Session.Accept{Uni}Stream`. ## v0.11.0 (2019-04-05) diff --git a/benchmark/benchmark_test.go b/benchmark/benchmark_test.go index 846fb84c..77e533da 100644 --- a/benchmark/benchmark_test.go +++ b/benchmark/benchmark_test.go @@ -67,7 +67,7 @@ func init() { ) Expect(err).ToNot(HaveOccurred()) close(handshakeChan) - str, err := sess.AcceptStream() + str, err := sess.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) buf := &bytes.Buffer{} diff --git a/example/echo/echo.go b/example/echo/echo.go index 77c6522f..1ee6c430 100644 --- a/example/echo/echo.go +++ b/example/echo/echo.go @@ -40,7 +40,7 @@ func echoServer() error { if err != nil { return err } - stream, err := sess.AcceptStream() + stream, err := sess.AcceptStream(context.Background()) if err != nil { panic(err) } diff --git a/http3/server.go b/http3/server.go index 689ec41f..80f5b1fc 100644 --- a/http3/server.go +++ b/http3/server.go @@ -138,7 +138,7 @@ func (s *Server) handleConn(sess quic.Session) { str.Write(buf.Bytes()) for { - str, err := sess.AcceptStream() + str, err := sess.AcceptStream(context.Background()) if err != nil { s.logger.Debugf("Accepting stream failed: %s", err) return diff --git a/integrationtests/self/cancelation_test.go b/integrationtests/self/cancelation_test.go index d2756d4a..64689826 100644 --- a/integrationtests/self/cancelation_test.go +++ b/integrationtests/self/cancelation_test.go @@ -76,7 +76,7 @@ var _ = Describe("Stream Cancelations", func() { go func() { defer GinkgoRecover() defer wg.Done() - str, err := sess.AcceptUniStream() + str, err := sess.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) // cancel around 2/3 of the streams if rand.Int31()%3 != 0 { @@ -120,7 +120,7 @@ var _ = Describe("Stream Cancelations", func() { go func() { defer GinkgoRecover() defer wg.Done() - str, err := sess.AcceptUniStream() + str, err := sess.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) // only read some data from about 1/3 of the streams if rand.Int31()%3 != 0 { @@ -168,7 +168,7 @@ var _ = Describe("Stream Cancelations", func() { go func() { defer GinkgoRecover() defer wg.Done() - str, err := sess.AcceptUniStream() + str, err := sess.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) data, err := ioutil.ReadAll(str) if err != nil { @@ -304,7 +304,7 @@ var _ = Describe("Stream Cancelations", func() { go func() { defer GinkgoRecover() defer wg.Done() - str, err := sess.AcceptUniStream() + str, err := sess.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) // cancel around half of the streams if rand.Int31()%2 == 0 { @@ -383,7 +383,7 @@ var _ = Describe("Stream Cancelations", func() { defer GinkgoRecover() defer wg.Done() - str, err := sess.AcceptUniStream() + str, err := sess.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) r := io.Reader(str) diff --git a/integrationtests/self/conn_id_test.go b/integrationtests/self/conn_id_test.go index bc9098d1..4b96e1da 100644 --- a/integrationtests/self/conn_id_test.go +++ b/integrationtests/self/conn_id_test.go @@ -52,7 +52,7 @@ var _ = Describe("Connection ID lengths tests", func() { ) Expect(err).ToNot(HaveOccurred()) defer cl.Close() - str, err := cl.AcceptStream() + str, err := cl.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) data, err := ioutil.ReadAll(str) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/deadline_test.go b/integrationtests/self/deadline_test.go index 829d7ec0..fc063793 100644 --- a/integrationtests/self/deadline_test.go +++ b/integrationtests/self/deadline_test.go @@ -30,7 +30,7 @@ var _ = Describe("Stream deadline tests", func() { defer GinkgoRecover() sess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - serverStr, err = sess.AcceptStream() + serverStr, err = sess.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) _, err = serverStr.Read([]byte{0}) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/drop_test.go b/integrationtests/self/drop_test.go index e8bdbba4..5a20ccd4 100644 --- a/integrationtests/self/drop_test.go +++ b/integrationtests/self/drop_test.go @@ -110,7 +110,7 @@ var _ = Describe("Drop Tests", func() { ) Expect(err).ToNot(HaveOccurred()) defer sess.Close() - str, err := sess.AcceptStream() + str, err := sess.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) for i := uint8(1); i <= numMessages; i++ { b := []byte{0} diff --git a/integrationtests/self/handshake_drop_test.go b/integrationtests/self/handshake_drop_test.go index 28dd95e8..724a72cc 100644 --- a/integrationtests/self/handshake_drop_test.go +++ b/integrationtests/self/handshake_drop_test.go @@ -61,7 +61,7 @@ var _ = Describe("Handshake drop tests", func() { sess, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) defer sess.Close() - str, err := sess.AcceptStream() + str, err := sess.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) b := make([]byte, 6) _, err = gbytes.TimeoutReader(str, 10*time.Second).Read(b) @@ -107,7 +107,7 @@ var _ = Describe("Handshake drop tests", func() { &quic.Config{Versions: []protocol.VersionNumber{version}}, ) Expect(err).ToNot(HaveOccurred()) - str, err := sess.AcceptStream() + str, err := sess.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) b := make([]byte, 6) _, err = gbytes.TimeoutReader(str, 10*time.Second).Read(b) diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index 17e2b90a..0192abea 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -159,7 +159,7 @@ var _ = Describe("Handshake tests", func() { errChan := make(chan error) go func() { defer GinkgoRecover() - _, err := sess.AcceptStream() + _, err := sess.AcceptStream(context.Background()) errChan <- err }() Eventually(errChan).Should(Receive(&err)) diff --git a/integrationtests/self/multiplex_test.go b/integrationtests/self/multiplex_test.go index f6b9f0ab..a524b29b 100644 --- a/integrationtests/self/multiplex_test.go +++ b/integrationtests/self/multiplex_test.go @@ -51,7 +51,7 @@ var _ = Describe("Multiplexing", func() { &quic.Config{Versions: []protocol.VersionNumber{version}}, ) Expect(err).ToNot(HaveOccurred()) - str, err := sess.AcceptStream() + str, err := sess.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) data, err := ioutil.ReadAll(str) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/rtt_test.go b/integrationtests/self/rtt_test.go index b3782433..eff2108a 100644 --- a/integrationtests/self/rtt_test.go +++ b/integrationtests/self/rtt_test.go @@ -68,7 +68,7 @@ var _ = Describe("non-zero RTT", func() { &quic.Config{Versions: []protocol.VersionNumber{version}}, ) Expect(err).ToNot(HaveOccurred()) - str, err := sess.AcceptStream() + str, err := sess.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) data, err := ioutil.ReadAll(str) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/stateless_reset_test.go b/integrationtests/self/stateless_reset_test.go index 85826d6b..aae2ed74 100644 --- a/integrationtests/self/stateless_reset_test.go +++ b/integrationtests/self/stateless_reset_test.go @@ -64,7 +64,7 @@ var _ = Describe("Stateless Resets", func() { }, ) Expect(err).ToNot(HaveOccurred()) - str, err := sess.AcceptStream() + str, err := sess.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) data := make([]byte, 6) _, err = str.Read(data) diff --git a/integrationtests/self/stream_test.go b/integrationtests/self/stream_test.go index bc2fbb93..f308ca98 100644 --- a/integrationtests/self/stream_test.go +++ b/integrationtests/self/stream_test.go @@ -71,7 +71,7 @@ var _ = Describe("Bidirectional streams", func() { var wg sync.WaitGroup wg.Add(numStreams) for i := 0; i < numStreams; i++ { - str, err := sess.AcceptStream() + str, err := sess.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) go func() { defer GinkgoRecover() diff --git a/integrationtests/self/timeout_test.go b/integrationtests/self/timeout_test.go index eaa23215..8a752e57 100644 --- a/integrationtests/self/timeout_test.go +++ b/integrationtests/self/timeout_test.go @@ -95,7 +95,7 @@ var _ = Describe("Timeout tests", func() { &quic.Config{IdleTimeout: idleTimeout}, ) Expect(err).ToNot(HaveOccurred()) - strIn, err := sess.AcceptStream() + strIn, err := sess.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) strOut, err := sess.OpenStream() Expect(err).ToNot(HaveOccurred()) @@ -116,9 +116,9 @@ var _ = Describe("Timeout tests", func() { checkTimeoutError(err) _, err = sess.OpenUniStream() checkTimeoutError(err) - _, err = sess.AcceptStream() + _, err = sess.AcceptStream(context.Background()) checkTimeoutError(err) - _, err = sess.AcceptUniStream() + _, err = sess.AcceptUniStream(context.Background()) checkTimeoutError(err) }) @@ -148,7 +148,7 @@ var _ = Describe("Timeout tests", func() { defer GinkgoRecover() sess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - sess.AcceptStream() // blocks until the session is closed + sess.AcceptStream(context.Background()) // blocks until the session is closed close(serverSessionClosed) }() @@ -162,7 +162,7 @@ var _ = Describe("Timeout tests", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - _, err := sess.AcceptStream() + _, err := sess.AcceptStream(context.Background()) checkTimeoutError(err) close(done) }() @@ -189,7 +189,7 @@ var _ = Describe("Timeout tests", func() { defer GinkgoRecover() sess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - sess.AcceptStream() // blocks until the session is closed + sess.AcceptStream(context.Background()) // blocks until the session is closed close(serverSessionClosed) }() @@ -212,7 +212,7 @@ var _ = Describe("Timeout tests", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - _, err := sess.AcceptStream() + _, err := sess.AcceptStream(context.Background()) checkTimeoutError(err) close(done) }() diff --git a/integrationtests/self/uni_stream_test.go b/integrationtests/self/uni_stream_test.go index 67bad1d2..d528c291 100644 --- a/integrationtests/self/uni_stream_test.go +++ b/integrationtests/self/uni_stream_test.go @@ -57,7 +57,7 @@ var _ = Describe("Unidirectional Streams", func() { var wg sync.WaitGroup wg.Add(numStreams) for i := 0; i < numStreams; i++ { - str, err := sess.AcceptUniStream() + str, err := sess.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) go func() { defer GinkgoRecover() diff --git a/interface.go b/interface.go index f8dc9471..7fa54ebf 100644 --- a/interface.go +++ b/interface.go @@ -127,11 +127,11 @@ type Session interface { // AcceptStream returns the next stream opened by the peer, blocking until one is available. // If the session was closed due to a timeout, the error satisfies // the net.Error interface, and Timeout() will be true. - AcceptStream() (Stream, error) + AcceptStream(context.Context) (Stream, error) // AcceptUniStream returns the next unidirectional stream opened by the peer, blocking until one is available. // If the session was closed due to a timeout, the error satisfies // the net.Error interface, and Timeout() will be true. - AcceptUniStream() (ReceiveStream, error) + AcceptUniStream(context.Context) (ReceiveStream, error) // OpenStream opens a new bidirectional QUIC stream. // There is no signaling to the peer about new streams: // The peer can only accept the stream after data has been sent on the stream. diff --git a/internal/mocks/quic/session.go b/internal/mocks/quic/session.go index ae6c6b22..5e14a4f3 100644 --- a/internal/mocks/quic/session.go +++ b/internal/mocks/quic/session.go @@ -39,33 +39,33 @@ func (m *MockSession) EXPECT() *MockSessionMockRecorder { } // AcceptStream mocks base method -func (m *MockSession) AcceptStream() (quic_go.Stream, error) { +func (m *MockSession) AcceptStream(arg0 context.Context) (quic_go.Stream, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AcceptStream") + ret := m.ctrl.Call(m, "AcceptStream", arg0) ret0, _ := ret[0].(quic_go.Stream) ret1, _ := ret[1].(error) return ret0, ret1 } // AcceptStream indicates an expected call of AcceptStream -func (mr *MockSessionMockRecorder) AcceptStream() *gomock.Call { +func (mr *MockSessionMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockSession)(nil).AcceptStream)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockSession)(nil).AcceptStream), arg0) } // AcceptUniStream mocks base method -func (m *MockSession) AcceptUniStream() (quic_go.ReceiveStream, error) { +func (m *MockSession) AcceptUniStream(arg0 context.Context) (quic_go.ReceiveStream, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AcceptUniStream") + ret := m.ctrl.Call(m, "AcceptUniStream", arg0) ret0, _ := ret[0].(quic_go.ReceiveStream) ret1, _ := ret[1].(error) return ret0, ret1 } // AcceptUniStream indicates an expected call of AcceptUniStream -func (mr *MockSessionMockRecorder) AcceptUniStream() *gomock.Call { +func (mr *MockSessionMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockSession)(nil).AcceptUniStream)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockSession)(nil).AcceptUniStream), arg0) } // Close mocks base method diff --git a/mock_quic_session_test.go b/mock_quic_session_test.go index 83333188..19dfbe7e 100644 --- a/mock_quic_session_test.go +++ b/mock_quic_session_test.go @@ -38,33 +38,33 @@ func (m *MockQuicSession) EXPECT() *MockQuicSessionMockRecorder { } // AcceptStream mocks base method -func (m *MockQuicSession) AcceptStream() (Stream, error) { +func (m *MockQuicSession) AcceptStream(arg0 context.Context) (Stream, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AcceptStream") + ret := m.ctrl.Call(m, "AcceptStream", arg0) ret0, _ := ret[0].(Stream) ret1, _ := ret[1].(error) return ret0, ret1 } // AcceptStream indicates an expected call of AcceptStream -func (mr *MockQuicSessionMockRecorder) AcceptStream() *gomock.Call { +func (mr *MockQuicSessionMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockQuicSession)(nil).AcceptStream)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockQuicSession)(nil).AcceptStream), arg0) } // AcceptUniStream mocks base method -func (m *MockQuicSession) AcceptUniStream() (ReceiveStream, error) { +func (m *MockQuicSession) AcceptUniStream(arg0 context.Context) (ReceiveStream, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AcceptUniStream") + ret := m.ctrl.Call(m, "AcceptUniStream", arg0) ret0, _ := ret[0].(ReceiveStream) ret1, _ := ret[1].(error) return ret0, ret1 } // AcceptUniStream indicates an expected call of AcceptUniStream -func (mr *MockQuicSessionMockRecorder) AcceptUniStream() *gomock.Call { +func (mr *MockQuicSessionMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockQuicSession)(nil).AcceptUniStream)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockQuicSession)(nil).AcceptUniStream), arg0) } // Close mocks base method diff --git a/mock_stream_manager_test.go b/mock_stream_manager_test.go index 1f965660..03959f9a 100644 --- a/mock_stream_manager_test.go +++ b/mock_stream_manager_test.go @@ -5,6 +5,7 @@ package quic import ( + context "context" reflect "reflect" gomock "github.com/golang/mock/gomock" @@ -37,33 +38,33 @@ func (m *MockStreamManager) EXPECT() *MockStreamManagerMockRecorder { } // AcceptStream mocks base method -func (m *MockStreamManager) AcceptStream() (Stream, error) { +func (m *MockStreamManager) AcceptStream(arg0 context.Context) (Stream, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AcceptStream") + ret := m.ctrl.Call(m, "AcceptStream", arg0) ret0, _ := ret[0].(Stream) ret1, _ := ret[1].(error) return ret0, ret1 } // AcceptStream indicates an expected call of AcceptStream -func (mr *MockStreamManagerMockRecorder) AcceptStream() *gomock.Call { +func (mr *MockStreamManagerMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptStream)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptStream), arg0) } // AcceptUniStream mocks base method -func (m *MockStreamManager) AcceptUniStream() (ReceiveStream, error) { +func (m *MockStreamManager) AcceptUniStream(arg0 context.Context) (ReceiveStream, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AcceptUniStream") + ret := m.ctrl.Call(m, "AcceptUniStream", arg0) ret0, _ := ret[0].(ReceiveStream) ret1, _ := ret[1].(error) return ret0, ret1 } // AcceptUniStream indicates an expected call of AcceptUniStream -func (mr *MockStreamManagerMockRecorder) AcceptUniStream() *gomock.Call { +func (mr *MockStreamManagerMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptUniStream)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptUniStream), arg0) } // CloseWithError mocks base method diff --git a/session.go b/session.go index 790f8266..87dad422 100644 --- a/session.go +++ b/session.go @@ -39,8 +39,8 @@ type streamManager interface { OpenUniStream() (SendStream, error) OpenStreamSync() (Stream, error) OpenUniStreamSync() (SendStream, error) - AcceptStream() (Stream, error) - AcceptUniStream() (ReceiveStream, error) + AcceptStream(context.Context) (Stream, error) + AcceptUniStream(context.Context) (ReceiveStream, error) DeleteStream(protocol.StreamID) error UpdateLimits(*handshake.TransportParameters) error HandleMaxStreamsFrame(*wire.MaxStreamsFrame) error @@ -1233,12 +1233,12 @@ func (s *session) logPacket(packet *packedPacket) { } // AcceptStream returns the next stream openend by the peer -func (s *session) AcceptStream() (Stream, error) { - return s.streamsMap.AcceptStream() +func (s *session) AcceptStream(ctx context.Context) (Stream, error) { + return s.streamsMap.AcceptStream(ctx) } -func (s *session) AcceptUniStream() (ReceiveStream, error) { - return s.streamsMap.AcceptUniStream() +func (s *session) AcceptUniStream(ctx context.Context) (ReceiveStream, error) { + return s.streamsMap.AcceptUniStream(ctx) } // OpenStream opens a stream diff --git a/session_test.go b/session_test.go index b721a21c..ddb2569c 100644 --- a/session_test.go +++ b/session_test.go @@ -367,14 +367,6 @@ var _ = Describe("Session", func() { Expect(sess.GetVersion()).To(Equal(protocol.VersionNumber(4242))) }) - It("accepts new streams", func() { - mstr := NewMockStreamI(mockCtrl) - streamManager.EXPECT().AcceptStream().Return(mstr, nil) - str, err := sess.AcceptStream() - Expect(err).ToNot(HaveOccurred()) - Expect(str).To(Equal(mstr)) - }) - Context("closing", func() { var ( runErr error @@ -1454,17 +1446,21 @@ var _ = Describe("Session", func() { }) It("accepts streams", func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() mstr := NewMockStreamI(mockCtrl) - streamManager.EXPECT().AcceptStream().Return(mstr, nil) - str, err := sess.AcceptStream() + streamManager.EXPECT().AcceptStream(ctx).Return(mstr, nil) + str, err := sess.AcceptStream(ctx) Expect(err).ToNot(HaveOccurred()) Expect(str).To(Equal(mstr)) }) It("accepts unidirectional streams", func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() mstr := NewMockReceiveStreamI(mockCtrl) - streamManager.EXPECT().AcceptUniStream().Return(mstr, nil) - str, err := sess.AcceptUniStream() + streamManager.EXPECT().AcceptUniStream(ctx).Return(mstr, nil) + str, err := sess.AcceptUniStream(ctx) Expect(err).ToNot(HaveOccurred()) Expect(str).To(Equal(mstr)) }) diff --git a/streams_map.go b/streams_map.go index cc53efd1..56304d4b 100644 --- a/streams_map.go +++ b/streams_map.go @@ -1,6 +1,7 @@ package quic import ( + "context" "errors" "fmt" "net" @@ -123,13 +124,13 @@ func (m *streamsMap) OpenUniStreamSync() (SendStream, error) { return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective) } -func (m *streamsMap) AcceptStream() (Stream, error) { - str, err := m.incomingBidiStreams.AcceptStream() +func (m *streamsMap) AcceptStream(ctx context.Context) (Stream, error) { + str, err := m.incomingBidiStreams.AcceptStream(ctx) return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective.Opposite()) } -func (m *streamsMap) AcceptUniStream() (ReceiveStream, error) { - str, err := m.incomingUniStreams.AcceptStream() +func (m *streamsMap) AcceptUniStream(ctx context.Context) (ReceiveStream, error) { + str, err := m.incomingUniStreams.AcceptStream(ctx) return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective.Opposite()) } diff --git a/streams_map_incoming_bidi.go b/streams_map_incoming_bidi.go index aa431558..f24b9ec2 100644 --- a/streams_map_incoming_bidi.go +++ b/streams_map_incoming_bidi.go @@ -5,6 +5,7 @@ package quic import ( + "context" "sync" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -50,24 +51,28 @@ func newIncomingBidiStreamsMap( } } -func (m *incomingBidiStreamsMap) AcceptStream() (streamI, error) { +func (m *incomingBidiStreamsMap) AcceptStream(ctx context.Context) (streamI, error) { m.mutex.Lock() - defer m.mutex.Unlock() var num protocol.StreamNum var str streamI for { num = m.nextStreamToAccept - var ok bool if m.closeErr != nil { + m.mutex.Unlock() return nil, m.closeErr } + var ok bool str, ok = m.streams[num] if ok { break } m.mutex.Unlock() - <-m.newStreamChan + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-m.newStreamChan: + } m.mutex.Lock() } m.nextStreamToAccept++ @@ -75,9 +80,11 @@ func (m *incomingBidiStreamsMap) AcceptStream() (streamI, error) { if _, ok := m.streamsToDelete[num]; ok { delete(m.streamsToDelete, num) if err := m.deleteStream(num); err != nil { + m.mutex.Unlock() return nil, err } } + m.mutex.Unlock() return str, nil } diff --git a/streams_map_incoming_generic.go b/streams_map_incoming_generic.go index 004daed5..f8ace8bf 100644 --- a/streams_map_incoming_generic.go +++ b/streams_map_incoming_generic.go @@ -1,6 +1,7 @@ package quic import ( + "context" "sync" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -48,24 +49,28 @@ func newIncomingItemsMap( } } -func (m *incomingItemsMap) AcceptStream() (item, error) { +func (m *incomingItemsMap) AcceptStream(ctx context.Context) (item, error) { m.mutex.Lock() - defer m.mutex.Unlock() var num protocol.StreamNum var str item for { num = m.nextStreamToAccept - var ok bool if m.closeErr != nil { + m.mutex.Unlock() return nil, m.closeErr } + var ok bool str, ok = m.streams[num] if ok { break } m.mutex.Unlock() - <-m.newStreamChan + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-m.newStreamChan: + } m.mutex.Lock() } m.nextStreamToAccept++ @@ -73,9 +78,11 @@ func (m *incomingItemsMap) AcceptStream() (item, error) { if _, ok := m.streamsToDelete[num]; ok { delete(m.streamsToDelete, num) if err := m.deleteStream(num); err != nil { + m.mutex.Unlock() return nil, err } } + m.mutex.Unlock() return str, nil } diff --git a/streams_map_incoming_generic_test.go b/streams_map_incoming_generic_test.go index 0a59b7f5..f62e6d5c 100644 --- a/streams_map_incoming_generic_test.go +++ b/streams_map_incoming_generic_test.go @@ -1,6 +1,7 @@ package quic import ( + "context" "errors" "github.com/golang/mock/gomock" @@ -66,10 +67,10 @@ var _ = Describe("Streams Map (incoming)", func() { It("accepts streams in the right order", func() { _, err := m.GetOrOpenStream(2) // open streams 1 and 2 Expect(err).ToNot(HaveOccurred()) - str, err := m.AcceptStream() + str, err := m.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) - str, err = m.AcceptStream() + str, err = m.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) }) @@ -90,7 +91,7 @@ var _ = Describe("Streams Map (incoming)", func() { strChan := make(chan item) go func() { defer GinkgoRecover() - str, err := m.AcceptStream() + str, err := m.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) strChan <- str }() @@ -103,12 +104,26 @@ var _ = Describe("Streams Map (incoming)", func() { Expect(acceptedStr.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) }) + It("unblocks AcceptStream when the context is canceled", func() { + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := m.AcceptStream(ctx) + Expect(err).To(MatchError("context canceled")) + close(done) + }() + Consistently(done).ShouldNot(BeClosed()) + cancel() + Eventually(done).Should(BeClosed()) + }) + It("unblocks AcceptStream when it is closed", func() { testErr := errors.New("test error") done := make(chan struct{}) go func() { defer GinkgoRecover() - _, err := m.AcceptStream() + _, err := m.AcceptStream(context.Background()) Expect(err).To(MatchError(testErr)) close(done) }() @@ -120,7 +135,7 @@ var _ = Describe("Streams Map (incoming)", func() { It("errors AcceptStream immediately if it is closed", func() { testErr := errors.New("test error") m.CloseWithError(testErr) - _, err := m.AcceptStream() + _, err := m.AcceptStream(context.Background()) Expect(err).To(MatchError(testErr)) }) @@ -141,7 +156,7 @@ var _ = Describe("Streams Map (incoming)", func() { mockSender.EXPECT().queueControlFrame(gomock.Any()) _, err := m.GetOrOpenStream(1) Expect(err).ToNot(HaveOccurred()) - str, err := m.AcceptStream() + str, err := m.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) Expect(m.DeleteStream(1)).To(Succeed()) @@ -154,12 +169,12 @@ var _ = Describe("Streams Map (incoming)", func() { _, err := m.GetOrOpenStream(2) Expect(err).ToNot(HaveOccurred()) Expect(m.DeleteStream(2)).To(Succeed()) - str, err := m.AcceptStream() + str, err := m.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) // when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued mockSender.EXPECT().queueControlFrame(gomock.Any()) - str, err = m.AcceptStream() + str, err = m.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) }) @@ -174,7 +189,7 @@ var _ = Describe("Streams Map (incoming)", func() { Expect(str).To(BeNil()) // when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued mockSender.EXPECT().queueControlFrame(gomock.Any()) - str, err = m.AcceptStream() + str, err = m.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(str).ToNot(BeNil()) }) @@ -191,7 +206,7 @@ var _ = Describe("Streams Map (incoming)", func() { Expect(err).ToNot(HaveOccurred()) // accept all streams for i := 0; i < 5; i++ { - _, err := m.AcceptStream() + _, err := m.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) } mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { diff --git a/streams_map_incoming_uni.go b/streams_map_incoming_uni.go index 1069ba19..c146f3ab 100644 --- a/streams_map_incoming_uni.go +++ b/streams_map_incoming_uni.go @@ -5,6 +5,7 @@ package quic import ( + "context" "sync" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -50,24 +51,28 @@ func newIncomingUniStreamsMap( } } -func (m *incomingUniStreamsMap) AcceptStream() (receiveStreamI, error) { +func (m *incomingUniStreamsMap) AcceptStream(ctx context.Context) (receiveStreamI, error) { m.mutex.Lock() - defer m.mutex.Unlock() var num protocol.StreamNum var str receiveStreamI for { num = m.nextStreamToAccept - var ok bool if m.closeErr != nil { + m.mutex.Unlock() return nil, m.closeErr } + var ok bool str, ok = m.streams[num] if ok { break } m.mutex.Unlock() - <-m.newStreamChan + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-m.newStreamChan: + } m.mutex.Lock() } m.nextStreamToAccept++ @@ -75,9 +80,11 @@ func (m *incomingUniStreamsMap) AcceptStream() (receiveStreamI, error) { if _, ok := m.streamsToDelete[num]; ok { delete(m.streamsToDelete, num) if err := m.deleteStream(num); err != nil { + m.mutex.Unlock() return nil, err } } + m.mutex.Unlock() return str, nil } diff --git a/streams_map_test.go b/streams_map_test.go index cbace466..7ab5bf7d 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -1,6 +1,7 @@ package quic import ( + "context" "errors" "fmt" "net" @@ -121,7 +122,7 @@ var _ = Describe("Streams Map", func() { It("accepts bidirectional streams", func() { _, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream) Expect(err).ToNot(HaveOccurred()) - str, err := m.AcceptStream() + str, err := m.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(str).To(BeAssignableToTypeOf(&stream{})) Expect(str.StreamID()).To(Equal(ids.firstIncomingBidiStream)) @@ -130,7 +131,7 @@ var _ = Describe("Streams Map", func() { It("accepts unidirectional streams", func() { _, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream) Expect(err).ToNot(HaveOccurred()) - str, err := m.AcceptUniStream() + str, err := m.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(str).To(BeAssignableToTypeOf(&receiveStream{})) Expect(str.StreamID()).To(Equal(ids.firstIncomingUniStream)) @@ -170,7 +171,7 @@ var _ = Describe("Streams Map", func() { _, err := m.GetOrOpenReceiveStream(id) Expect(err).ToNot(HaveOccurred()) Expect(m.DeleteStream(id)).To(Succeed()) - str, err := m.AcceptStream() + str, err := m.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(str).ToNot(BeNil()) Expect(str.StreamID()).To(Equal(id)) @@ -203,7 +204,7 @@ var _ = Describe("Streams Map", func() { _, err := m.GetOrOpenReceiveStream(id) Expect(err).ToNot(HaveOccurred()) Expect(m.DeleteStream(id)).To(Succeed()) - str, err := m.AcceptUniStream() + str, err := m.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(str).ToNot(BeNil()) Expect(str.StreamID()).To(Equal(id)) @@ -393,7 +394,7 @@ var _ = Describe("Streams Map", func() { It("sends a MAX_STREAMS frame for bidirectional streams", func() { _, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream) Expect(err).ToNot(HaveOccurred()) - _, err = m.AcceptStream() + _, err = m.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{ Type: protocol.StreamTypeBidi, @@ -405,7 +406,7 @@ var _ = Describe("Streams Map", func() { It("sends a MAX_STREAMS frame for unidirectional streams", func() { _, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream) Expect(err).ToNot(HaveOccurred()) - _, err = m.AcceptUniStream() + _, err = m.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{ Type: protocol.StreamTypeUni, @@ -424,10 +425,10 @@ var _ = Describe("Streams Map", func() { _, err = m.OpenUniStream() Expect(err).To(HaveOccurred()) Expect(err.Error()).To(Equal(testErr.Error())) - _, err = m.AcceptStream() + _, err = m.AcceptStream(context.Background()) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(Equal(testErr.Error())) - _, err = m.AcceptUniStream() + _, err = m.AcceptUniStream(context.Background()) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(Equal(testErr.Error())) })