diff --git a/http3/server.go b/http3/server.go index c17b16b5..cc8dd027 100644 --- a/http3/server.go +++ b/http3/server.go @@ -12,6 +12,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" "github.com/quic-go/quic-go" @@ -498,6 +499,8 @@ func (s *Server) handleConn(conn quic.Connection) error { } func (s *Server) handleUnidirectionalStreams(conn quic.Connection) { + var rcvdControlStream atomic.Bool + for { str, err := conn.AcceptUniStream(context.Background()) if err != nil { @@ -531,6 +534,11 @@ func (s *Server) handleUnidirectionalStreams(conn quic.Connection) { str.CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)) return } + // Only a single control stream is allowed. + if isFirstControlStr := rcvdControlStream.CompareAndSwap(false, true); !isFirstControlStr { + conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream") + return + } f, err := parseNextFrame(str, nil) if err != nil { conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "") diff --git a/http3/server_test.go b/http3/server_test.go index 1e517ff5..072d3e17 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -16,6 +16,7 @@ import ( "github.com/quic-go/quic-go" mockquic "github.com/quic-go/quic-go/internal/mocks/quic" "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/testdata" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/quicvarint" @@ -497,7 +498,7 @@ var _ = Describe("Server", func() { Context("control stream handling", func() { var conn *mockquic.MockEarlyConnection - testDone := make(chan struct{}) + testDone := make(chan struct{}, 1) BeforeEach(func() { conn = mockquic.NewMockEarlyConnection(mockCtrl) @@ -528,6 +529,34 @@ var _ = Describe("Server", func() { time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError }) + It("rejects duplicate control streams", func() { + b := quicvarint.Append(nil, streamTypeControlStream) + b = (&settingsFrame{}).Append(b) + r1 := bytes.NewReader(b) + controlStr1 := mockquic.NewMockStream(mockCtrl) + controlStr1.EXPECT().Read(gomock.Any()).DoAndReturn(r1.Read).AnyTimes() + r2 := bytes.NewReader(b) + controlStr2 := mockquic.NewMockStream(mockCtrl) + controlStr2.EXPECT().Read(gomock.Any()).DoAndReturn(r2.Read).AnyTimes() + done := make(chan struct{}) + conn.EXPECT().CloseWithError(qerr.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream").Do(func(qerr.ApplicationErrorCode, string) error { + close(done) + return nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return controlStr1, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return controlStr2, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-done + return nil, errors.New("test done") + }) + s.handleConn(conn) + Eventually(done).Should(BeClosed()) + }) + for _, t := range []uint64{streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream} { streamType := t name := "encoder"