pass an error code to Session.CloseWithError

This commit is contained in:
Marten Seemann 2018-07-04 17:47:10 +07:00
parent 8b2992a243
commit 6d3d6dc47d
7 changed files with 27 additions and 22 deletions

View file

@ -172,7 +172,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
responseChan := make(chan *http.Response) responseChan := make(chan *http.Response)
dataStream, err := c.session.OpenStreamSync() dataStream, err := c.session.OpenStreamSync()
if err != nil { if err != nil {
_ = c.CloseWithError(err) _ = c.closeWithError(err)
return nil, err return nil, err
} }
c.mutex.Lock() c.mutex.Lock()
@ -187,7 +187,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
endStream := !hasBody endStream := !hasBody
err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip) err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip)
if err != nil { if err != nil {
_ = c.CloseWithError(err) _ = c.closeWithError(err)
return nil, err return nil, err
} }
@ -230,7 +230,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
return nil, ctx.Err() return nil, ctx.Err()
case <-c.headerErrored: case <-c.headerErrored:
// an error occurred on the header stream // an error occurred on the header stream
_ = c.CloseWithError(c.headerErr) _ = c.closeWithError(c.headerErr)
return nil, c.headerErr return nil, c.headerErr
} }
} }
@ -275,14 +275,14 @@ func (c *client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (e
return dataStream.Close() return dataStream.Close()
} }
// Close closes the client func (c *client) closeWithError(e error) error {
func (c *client) CloseWithError(e error) error {
if c.session == nil { if c.session == nil {
return nil return nil
} }
return c.session.CloseWithError(e) return c.session.CloseWithError(quic.ErrorCode(qerr.InternalError), e)
} }
// Close closes the client
func (c *client) Close() error { func (c *client) Close() error {
if c.session == nil { if c.session == nil {
return nil return nil

View file

@ -127,7 +127,7 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
func (s *Server) handleHeaderStream(session streamCreator) { func (s *Server) handleHeaderStream(session streamCreator) {
stream, err := session.AcceptStream() stream, err := session.AcceptStream()
if err != nil { if err != nil {
session.CloseWithError(qerr.Error(qerr.InvalidHeadersStreamData, err.Error())) session.CloseWithError(quic.ErrorCode(qerr.InvalidHeadersStreamData), err)
return return
} }
@ -140,10 +140,12 @@ func (s *Server) handleHeaderStream(session streamCreator) {
// QuicErrors must originate from stream.Read() returning an error. // QuicErrors must originate from stream.Read() returning an error.
// In this case, the session has already logged the error, so we don't // In this case, the session has already logged the error, so we don't
// need to log it again. // need to log it again.
if _, ok := err.(*qerr.QuicError); !ok { errorCode := qerr.InternalError
if qerr, ok := err.(*qerr.QuicError); !ok {
errorCode = qerr.ErrorCode
s.logger.Errorf("error handling h2 request: %s", err.Error()) s.logger.Errorf("error handling h2 request: %s", err.Error())
} }
session.CloseWithError(err) session.CloseWithError(quic.ErrorCode(errorCode), err)
return return
} }
} }

View file

@ -69,7 +69,7 @@ func (s *mockSession) Close() error {
s.closed = true s.closed = true
return nil return nil
} }
func (s *mockSession) CloseWithError(e error) error { func (s *mockSession) CloseWithError(_ quic.ErrorCode, e error) error {
s.closedWithError = e s.closedWithError = e
return s.Close() return s.Close()
} }

View file

@ -148,7 +148,8 @@ type Session interface {
// Close the connection. // Close the connection.
io.Closer io.Closer
// Close the connection with an error. // Close the connection with an error.
CloseWithError(error) error // The error must not be nil.
CloseWithError(ErrorCode, error) error
// The context is cancelled when the session is closed. // The context is cancelled when the session is closed.
// Warning: This API should not be considered stable and might change soon. // Warning: This API should not be considered stable and might change soon.
Context() context.Context Context() context.Context

View file

@ -76,15 +76,15 @@ func (mr *MockQuicSessionMockRecorder) Close() *gomock.Call {
} }
// CloseWithError mocks base method // CloseWithError mocks base method
func (m *MockQuicSession) CloseWithError(arg0 error) error { func (m *MockQuicSession) CloseWithError(arg0 protocol.ApplicationErrorCode, arg1 error) error {
ret := m.ctrl.Call(m, "CloseWithError", arg0) ret := m.ctrl.Call(m, "CloseWithError", arg0, arg1)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
return ret0 return ret0
} }
// CloseWithError indicates an expected call of CloseWithError // CloseWithError indicates an expected call of CloseWithError
func (mr *MockQuicSessionMockRecorder) CloseWithError(arg0 interface{}) *gomock.Call { func (mr *MockQuicSessionMockRecorder) CloseWithError(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockQuicSession)(nil).CloseWithError), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockQuicSession)(nil).CloseWithError), arg0, arg1)
} }
// ConnectionState mocks base method // ConnectionState mocks base method

View file

@ -853,11 +853,13 @@ func (s *session) closeRemote(e error) {
// Close the connection. It sends a qerr.PeerGoingAway. // Close the connection. It sends a qerr.PeerGoingAway.
// It waits until the run loop has stopped before returning // It waits until the run loop has stopped before returning
func (s *session) Close() error { func (s *session) Close() error {
return s.CloseWithError(nil) s.closeLocal(nil)
<-s.ctx.Done()
return nil
} }
func (s *session) CloseWithError(e error) error { func (s *session) CloseWithError(code protocol.ApplicationErrorCode, e error) error {
s.closeLocal(e) s.closeLocal(qerr.Error(qerr.ErrorCode(code), e.Error()))
<-s.ctx.Done() <-s.ctx.Done()
return nil return nil
} }

View file

@ -535,9 +535,9 @@ var _ = Describe("Session", func() {
It("closes streams with proper error", func() { It("closes streams with proper error", func() {
testErr := errors.New("test error") testErr := errors.New("test error")
streamManager.EXPECT().CloseWithError(qerr.Error(qerr.InternalError, testErr.Error())) streamManager.EXPECT().CloseWithError(qerr.Error(0x1337, testErr.Error()))
sessionRunner.EXPECT().removeConnectionID(gomock.Any()) sessionRunner.EXPECT().removeConnectionID(gomock.Any())
sess.CloseWithError(testErr) sess.CloseWithError(0x1337, testErr)
Eventually(areSessionsRunning).Should(BeFalse()) Eventually(areSessionsRunning).Should(BeFalse())
Expect(sess.Context().Done()).To(BeClosed()) Expect(sess.Context().Done()).To(BeClosed())
}) })
@ -1476,12 +1476,12 @@ var _ = Describe("Session", func() {
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
err := sess.run() err := sess.run()
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(qerr.Error(0x1337, testErr.Error())))
close(done) close(done)
}() }()
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
sessionRunner.EXPECT().removeConnectionID(gomock.Any()) sessionRunner.EXPECT().removeConnectionID(gomock.Any())
Expect(sess.CloseWithError(testErr)).To(Succeed()) Expect(sess.CloseWithError(0x1337, testErr)).To(Succeed())
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })