fix race when stream.Read and CancelRead are called concurrently

Marten Seemann 2021-07-29 13:41:40 +02:00
parent 8906148682
commit fbc30cd942
2 changed files with 54 additions and 5 deletions
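
A note on the race itself: the two flow-controller functions touched below suggest where the conflict sits. AddBytesRead, used by the stream's Read path to report consumed bytes, updates bytesRead, while the old Abandon, reached when the read side is canceled, computed highestReceived - bytesRead without taking the controller's mutex. The stand-alone sketch below reproduces that access pattern with invented names and types; it is an illustration of the bug class, not quic-go code. Running it with go run -race reports the race.

package main

import (
	"fmt"
	"sync"
)

// raceSketch mimics the pre-fix access pattern: one method writes a
// counter while holding a mutex, another reads the same counter without it.
type raceSketch struct {
	mu              sync.Mutex
	bytesRead       int
	highestReceived int
}

// addBytesRead plays the role of the Read path: it updates bytesRead
// under the mutex.
func (s *raceSketch) addBytesRead(n int) {
	s.mu.Lock()
	s.bytesRead += n
	s.mu.Unlock()
}

// abandon plays the role of the old Abandon: it reads both counters
// without locking, which races with addBytesRead.
func (s *raceSketch) abandon() int {
	return s.highestReceived - s.bytesRead
}

func main() {
	s := &raceSketch{highestReceived: 100}
	var wg sync.WaitGroup
	wg.Add(2)
	go func() { defer wg.Done(); s.addBytesRead(32) }()       // the "Read" side
	go func() { defer wg.Done(); fmt.Println(s.abandon()) }() // the "CancelRead" side
	wg.Wait()
}

The actual fix, in the second file of this commit, instead takes the controller's mutex around the subtraction in Abandon.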

integrationtests/self/cancelation_test.go

@@ -25,7 +25,8 @@ var _ = Describe("Stream Cancelations", func() {
// The server accepts a single session, and then opens numStreams unidirectional streams.
// On each of these streams, it (tries to) write PRData.
- runServer := func() <-chan int32 {
+ // When done, it sends the number of canceled streams on the channel.
+ runServer := func(data []byte) <-chan int32 {
numCanceledStreamsChan := make(chan int32)
var err error
server, err = quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
@@ -44,7 +45,7 @@ var _ = Describe("Stream Cancelations", func() {
defer wg.Done()
str, err := sess.OpenUniStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
- if _, err := str.Write(PRData); err != nil {
+ if _, err := str.Write(data); err != nil {
Expect(err).To(MatchError(&quic.StreamError{
StreamID: str.StreamID(),
ErrorCode: quic.StreamErrorCode(str.StreamID()),
@@ -70,7 +71,7 @@ var _ = Describe("Stream Cancelations", func() {
})
It("downloads when the client immediately cancels most streams", func() {
- serverCanceledCounterChan := runServer()
+ serverCanceledCounterChan := runServer(PRData)
sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
@@ -113,7 +114,7 @@ var _ = Describe("Stream Cancelations", func() {
})
It("downloads when the client cancels streams after reading from them for a bit", func() {
- serverCanceledCounterChan := runServer()
+ serverCanceledCounterChan := runServer(PRData)
sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
@@ -159,6 +160,51 @@ var _ = Describe("Stream Cancelations", func() {
Expect(clientCanceledCounter).To(BeNumerically(">", numStreams/10))
Expect(numStreams - clientCanceledCounter).To(BeNumerically(">", numStreams/10))
})
It("allows concurrent Read and CancelRead calls", func() {
// This test is especially valuable when run with race detector,
// see https://github.com/lucas-clemente/quic-go/issues/3239.
serverCanceledCounterChan := runServer(make([]byte, 100)) // make sure the FIN is sent with the STREAM frame
sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(&quic.Config{MaxIncomingUniStreams: numStreams / 2}),
)
Expect(err).ToNot(HaveOccurred())
var wg sync.WaitGroup
wg.Add(numStreams)
var counter int32
for i := 0; i < numStreams; i++ {
go func() {
defer GinkgoRecover()
defer wg.Done()
str, err := sess.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(done)
b := make([]byte, 32)
if _, err := str.Read(b); err != nil {
atomic.AddInt32(&counter, 1)
Expect(err.Error()).To(ContainSubstring("canceled with error code 1234"))
return
}
}()
go str.CancelRead(1234)
Eventually(done).Should(BeClosed())
}()
}
wg.Wait()
Expect(sess.CloseWithError(0, "")).To(Succeed())
numCanceled := atomic.LoadInt32(&counter)
fmt.Fprintf(GinkgoWriter, "canceled %d out of %d streams", numCanceled, numStreams)
Expect(numCanceled).ToNot(BeZero())
Eventually(serverCanceledCounterChan).Should(Receive())
})
})
Context("canceling the write side", func() {
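
For reference, the per-stream interaction the new test exercises boils down to the helper sketched here. It is written against the quic-go API visible in the diff (a quic.ReceiveStream obtained from AcceptUniStream, CancelRead taking a quic.StreamErrorCode); the package and function names are made up and it is not part of the commit. The race it used to trigger only shows up when the tests run under the Go race detector.

package example

import (
	"github.com/lucas-clemente/quic-go"
)

// readAndCancelConcurrently reads from a unidirectional stream on one
// goroutine while canceling the read side from another, which is the
// combination that used to race inside the stream's flow controller.
func readAndCancelConcurrently(str quic.ReceiveStream, code quic.StreamErrorCode) {
	done := make(chan struct{})
	go func() {
		defer close(done)
		buf := make([]byte, 32)
		// The read fails with a stream error if the cancellation wins.
		_, _ = str.Read(buf)
	}()
	go str.CancelRead(code)
	<-done
}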

internal/flowcontrol/stream_flow_controller.go

@@ -109,7 +109,10 @@ func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) {
}
func (c *streamFlowController) Abandon() {
- if unread := c.highestReceived - c.bytesRead; unread > 0 {
+ c.mutex.Lock()
+ unread := c.highestReceived - c.bytesRead
+ c.mutex.Unlock()
+ if unread > 0 {
c.connection.AddBytesRead(unread)
}
}
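
One way to read the shape of the fix above: only the subtraction touches the stream controller's own counters, so only that part is wrapped in the mutex; the unread bytes are credited to the connection-level controller after unlocking, presumably so the stream controller's lock is not held while c.connection.AddBytesRead performs its own synchronization.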