diff --git a/h2quic/client.go b/h2quic/client.go index 9ab7a65e..ac28a7f0 100644 --- a/h2quic/client.go +++ b/h2quic/client.go @@ -172,7 +172,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { responseChan := make(chan *http.Response) dataStream, err := c.session.OpenStreamSync() if err != nil { - _ = c.CloseWithError(err) + _ = c.closeWithError(err) return nil, err } c.mutex.Lock() @@ -187,7 +187,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { endStream := !hasBody err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip) if err != nil { - _ = c.CloseWithError(err) + _ = c.closeWithError(err) return nil, err } @@ -230,7 +230,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { return nil, ctx.Err() case <-c.headerErrored: // an error occurred on the header stream - _ = c.CloseWithError(c.headerErr) + _ = c.closeWithError(c.headerErr) return nil, c.headerErr } } @@ -275,14 +275,14 @@ func (c *client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (e return dataStream.Close() } -// Close closes the client -func (c *client) CloseWithError(e error) error { +func (c *client) closeWithError(e error) error { if c.session == nil { return nil } - return c.session.CloseWithError(e) + return c.session.CloseWithError(quic.ErrorCode(qerr.InternalError), e) } +// Close closes the client func (c *client) Close() error { if c.session == nil { return nil diff --git a/h2quic/server.go b/h2quic/server.go index 2b76b0c4..0d787bdd 100644 --- a/h2quic/server.go +++ b/h2quic/server.go @@ -127,7 +127,7 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error { func (s *Server) handleHeaderStream(session streamCreator) { stream, err := session.AcceptStream() if err != nil { - session.CloseWithError(qerr.Error(qerr.InvalidHeadersStreamData, err.Error())) + session.CloseWithError(quic.ErrorCode(qerr.InvalidHeadersStreamData), err) return } @@ -140,10 +140,12 @@ func (s *Server) handleHeaderStream(session streamCreator) { // QuicErrors must originate from stream.Read() returning an error. // In this case, the session has already logged the error, so we don't // need to log it again. - if _, ok := err.(*qerr.QuicError); !ok { + errorCode := qerr.InternalError + if qerr, ok := err.(*qerr.QuicError); !ok { + errorCode = qerr.ErrorCode s.logger.Errorf("error handling h2 request: %s", err.Error()) } - session.CloseWithError(err) + session.CloseWithError(quic.ErrorCode(errorCode), err) return } } diff --git a/h2quic/server_test.go b/h2quic/server_test.go index 69a8c96c..6125f2bc 100644 --- a/h2quic/server_test.go +++ b/h2quic/server_test.go @@ -69,7 +69,7 @@ func (s *mockSession) Close() error { s.closed = true return nil } -func (s *mockSession) CloseWithError(e error) error { +func (s *mockSession) CloseWithError(_ quic.ErrorCode, e error) error { s.closedWithError = e return s.Close() } diff --git a/interface.go b/interface.go index 3eac4c7b..fa1e19c5 100644 --- a/interface.go +++ b/interface.go @@ -148,7 +148,8 @@ type Session interface { // Close the connection. io.Closer // Close the connection with an error. - CloseWithError(error) error + // The error must not be nil. + CloseWithError(ErrorCode, error) error // The context is cancelled when the session is closed. // Warning: This API should not be considered stable and might change soon. Context() context.Context diff --git a/mock_quic_session_test.go b/mock_quic_session_test.go index eb3e42e5..3cbb6383 100644 --- a/mock_quic_session_test.go +++ b/mock_quic_session_test.go @@ -76,15 +76,15 @@ func (mr *MockQuicSessionMockRecorder) Close() *gomock.Call { } // CloseWithError mocks base method -func (m *MockQuicSession) CloseWithError(arg0 error) error { - ret := m.ctrl.Call(m, "CloseWithError", arg0) +func (m *MockQuicSession) CloseWithError(arg0 protocol.ApplicationErrorCode, arg1 error) error { + ret := m.ctrl.Call(m, "CloseWithError", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // CloseWithError indicates an expected call of CloseWithError -func (mr *MockQuicSessionMockRecorder) CloseWithError(arg0 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockQuicSession)(nil).CloseWithError), arg0) +func (mr *MockQuicSessionMockRecorder) CloseWithError(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockQuicSession)(nil).CloseWithError), arg0, arg1) } // ConnectionState mocks base method diff --git a/session.go b/session.go index 4ed42d28..65032f62 100644 --- a/session.go +++ b/session.go @@ -853,11 +853,13 @@ func (s *session) closeRemote(e error) { // Close the connection. It sends a qerr.PeerGoingAway. // It waits until the run loop has stopped before returning func (s *session) Close() error { - return s.CloseWithError(nil) + s.closeLocal(nil) + <-s.ctx.Done() + return nil } -func (s *session) CloseWithError(e error) error { - s.closeLocal(e) +func (s *session) CloseWithError(code protocol.ApplicationErrorCode, e error) error { + s.closeLocal(qerr.Error(qerr.ErrorCode(code), e.Error())) <-s.ctx.Done() return nil } diff --git a/session_test.go b/session_test.go index cc5820d7..9c13766a 100644 --- a/session_test.go +++ b/session_test.go @@ -535,9 +535,9 @@ var _ = Describe("Session", func() { It("closes streams with proper error", func() { testErr := errors.New("test error") - streamManager.EXPECT().CloseWithError(qerr.Error(qerr.InternalError, testErr.Error())) + streamManager.EXPECT().CloseWithError(qerr.Error(0x1337, testErr.Error())) sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - sess.CloseWithError(testErr) + sess.CloseWithError(0x1337, testErr) Eventually(areSessionsRunning).Should(BeFalse()) Expect(sess.Context().Done()).To(BeClosed()) }) @@ -1476,12 +1476,12 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() err := sess.run() - Expect(err).To(MatchError(testErr)) + Expect(err).To(MatchError(qerr.Error(0x1337, testErr.Error()))) close(done) }() streamManager.EXPECT().CloseWithError(gomock.Any()) sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - Expect(sess.CloseWithError(testErr)).To(Succeed()) + Expect(sess.CloseWithError(0x1337, testErr)).To(Succeed()) Eventually(done).Should(BeClosed()) })