mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-05 21:27:35 +03:00
pass an error code to Session.CloseWithError
This commit is contained in:
parent
8b2992a243
commit
6d3d6dc47d
7 changed files with 27 additions and 22 deletions
|
@ -172,7 +172,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
responseChan := make(chan *http.Response)
|
responseChan := make(chan *http.Response)
|
||||||
dataStream, err := c.session.OpenStreamSync()
|
dataStream, err := c.session.OpenStreamSync()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = c.CloseWithError(err)
|
_ = c.closeWithError(err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
c.mutex.Lock()
|
c.mutex.Lock()
|
||||||
|
@ -187,7 +187,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
endStream := !hasBody
|
endStream := !hasBody
|
||||||
err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip)
|
err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = c.CloseWithError(err)
|
_ = c.closeWithError(err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -230,7 +230,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
return nil, ctx.Err()
|
return nil, ctx.Err()
|
||||||
case <-c.headerErrored:
|
case <-c.headerErrored:
|
||||||
// an error occurred on the header stream
|
// an error occurred on the header stream
|
||||||
_ = c.CloseWithError(c.headerErr)
|
_ = c.closeWithError(c.headerErr)
|
||||||
return nil, c.headerErr
|
return nil, c.headerErr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -275,14 +275,14 @@ func (c *client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (e
|
||||||
return dataStream.Close()
|
return dataStream.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the client
|
func (c *client) closeWithError(e error) error {
|
||||||
func (c *client) CloseWithError(e error) error {
|
|
||||||
if c.session == nil {
|
if c.session == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return c.session.CloseWithError(e)
|
return c.session.CloseWithError(quic.ErrorCode(qerr.InternalError), e)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close closes the client
|
||||||
func (c *client) Close() error {
|
func (c *client) Close() error {
|
||||||
if c.session == nil {
|
if c.session == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -127,7 +127,7 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
|
||||||
func (s *Server) handleHeaderStream(session streamCreator) {
|
func (s *Server) handleHeaderStream(session streamCreator) {
|
||||||
stream, err := session.AcceptStream()
|
stream, err := session.AcceptStream()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
session.CloseWithError(qerr.Error(qerr.InvalidHeadersStreamData, err.Error()))
|
session.CloseWithError(quic.ErrorCode(qerr.InvalidHeadersStreamData), err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -140,10 +140,12 @@ func (s *Server) handleHeaderStream(session streamCreator) {
|
||||||
// QuicErrors must originate from stream.Read() returning an error.
|
// QuicErrors must originate from stream.Read() returning an error.
|
||||||
// In this case, the session has already logged the error, so we don't
|
// In this case, the session has already logged the error, so we don't
|
||||||
// need to log it again.
|
// need to log it again.
|
||||||
if _, ok := err.(*qerr.QuicError); !ok {
|
errorCode := qerr.InternalError
|
||||||
|
if qerr, ok := err.(*qerr.QuicError); !ok {
|
||||||
|
errorCode = qerr.ErrorCode
|
||||||
s.logger.Errorf("error handling h2 request: %s", err.Error())
|
s.logger.Errorf("error handling h2 request: %s", err.Error())
|
||||||
}
|
}
|
||||||
session.CloseWithError(err)
|
session.CloseWithError(quic.ErrorCode(errorCode), err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -69,7 +69,7 @@ func (s *mockSession) Close() error {
|
||||||
s.closed = true
|
s.closed = true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
func (s *mockSession) CloseWithError(e error) error {
|
func (s *mockSession) CloseWithError(_ quic.ErrorCode, e error) error {
|
||||||
s.closedWithError = e
|
s.closedWithError = e
|
||||||
return s.Close()
|
return s.Close()
|
||||||
}
|
}
|
||||||
|
|
|
@ -148,7 +148,8 @@ type Session interface {
|
||||||
// Close the connection.
|
// Close the connection.
|
||||||
io.Closer
|
io.Closer
|
||||||
// Close the connection with an error.
|
// Close the connection with an error.
|
||||||
CloseWithError(error) error
|
// The error must not be nil.
|
||||||
|
CloseWithError(ErrorCode, error) error
|
||||||
// The context is cancelled when the session is closed.
|
// The context is cancelled when the session is closed.
|
||||||
// Warning: This API should not be considered stable and might change soon.
|
// Warning: This API should not be considered stable and might change soon.
|
||||||
Context() context.Context
|
Context() context.Context
|
||||||
|
|
|
@ -76,15 +76,15 @@ func (mr *MockQuicSessionMockRecorder) Close() *gomock.Call {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseWithError mocks base method
|
// CloseWithError mocks base method
|
||||||
func (m *MockQuicSession) CloseWithError(arg0 error) error {
|
func (m *MockQuicSession) CloseWithError(arg0 protocol.ApplicationErrorCode, arg1 error) error {
|
||||||
ret := m.ctrl.Call(m, "CloseWithError", arg0)
|
ret := m.ctrl.Call(m, "CloseWithError", arg0, arg1)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseWithError indicates an expected call of CloseWithError
|
// CloseWithError indicates an expected call of CloseWithError
|
||||||
func (mr *MockQuicSessionMockRecorder) CloseWithError(arg0 interface{}) *gomock.Call {
|
func (mr *MockQuicSessionMockRecorder) CloseWithError(arg0, arg1 interface{}) *gomock.Call {
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockQuicSession)(nil).CloseWithError), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockQuicSession)(nil).CloseWithError), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConnectionState mocks base method
|
// ConnectionState mocks base method
|
||||||
|
|
|
@ -853,11 +853,13 @@ func (s *session) closeRemote(e error) {
|
||||||
// Close the connection. It sends a qerr.PeerGoingAway.
|
// Close the connection. It sends a qerr.PeerGoingAway.
|
||||||
// It waits until the run loop has stopped before returning
|
// It waits until the run loop has stopped before returning
|
||||||
func (s *session) Close() error {
|
func (s *session) Close() error {
|
||||||
return s.CloseWithError(nil)
|
s.closeLocal(nil)
|
||||||
|
<-s.ctx.Done()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) CloseWithError(e error) error {
|
func (s *session) CloseWithError(code protocol.ApplicationErrorCode, e error) error {
|
||||||
s.closeLocal(e)
|
s.closeLocal(qerr.Error(qerr.ErrorCode(code), e.Error()))
|
||||||
<-s.ctx.Done()
|
<-s.ctx.Done()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -535,9 +535,9 @@ var _ = Describe("Session", func() {
|
||||||
|
|
||||||
It("closes streams with proper error", func() {
|
It("closes streams with proper error", func() {
|
||||||
testErr := errors.New("test error")
|
testErr := errors.New("test error")
|
||||||
streamManager.EXPECT().CloseWithError(qerr.Error(qerr.InternalError, testErr.Error()))
|
streamManager.EXPECT().CloseWithError(qerr.Error(0x1337, testErr.Error()))
|
||||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||||
sess.CloseWithError(testErr)
|
sess.CloseWithError(0x1337, testErr)
|
||||||
Eventually(areSessionsRunning).Should(BeFalse())
|
Eventually(areSessionsRunning).Should(BeFalse())
|
||||||
Expect(sess.Context().Done()).To(BeClosed())
|
Expect(sess.Context().Done()).To(BeClosed())
|
||||||
})
|
})
|
||||||
|
@ -1476,12 +1476,12 @@ var _ = Describe("Session", func() {
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
err := sess.run()
|
err := sess.run()
|
||||||
Expect(err).To(MatchError(testErr))
|
Expect(err).To(MatchError(qerr.Error(0x1337, testErr.Error())))
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||||
Expect(sess.CloseWithError(testErr)).To(Succeed())
|
Expect(sess.CloseWithError(0x1337, testErr)).To(Succeed())
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue