add a context to Session.Open{Uni}StreamSync

This commit is contained in:
Marten Seemann 2019-06-07 16:19:56 +08:00
parent e63a991950
commit 2b8cece60a
20 changed files with 218 additions and 104 deletions

View file

@ -5,6 +5,7 @@
package quic
import (
"context"
"sync"
"github.com/lucas-clemente/quic-go/internal/protocol"
@ -14,10 +15,12 @@ import (
type outgoingUniStreamsMap struct {
mutex sync.RWMutex
openQueue []chan struct{}
streams map[protocol.StreamNum]sendStreamI
openQueue map[uint64]chan struct{}
lowestInQueue uint64
highestInQueue uint64
nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync)
maxStream protocol.StreamNum // the maximum stream ID we're allowed to open
blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream
@ -34,6 +37,7 @@ func newOutgoingUniStreamsMap(
) *outgoingUniStreamsMap {
return &outgoingUniStreamsMap{
streams: make(map[protocol.StreamNum]sendStreamI),
openQueue: make(map[uint64]chan struct{}),
maxStream: protocol.InvalidStreamNum,
nextStream: 1,
newStream: newStream,
@ -57,7 +61,7 @@ func (m *outgoingUniStreamsMap) OpenStream() (sendStreamI, error) {
return m.openStream(), nil
}
func (m *outgoingUniStreamsMap) OpenStreamSync() (sendStreamI, error) {
func (m *outgoingUniStreamsMap) OpenStreamSync(ctx context.Context) (sendStreamI, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
@ -65,17 +69,32 @@ func (m *outgoingUniStreamsMap) OpenStreamSync() (sendStreamI, error) {
return nil, m.closeErr
}
if err := ctx.Err(); err != nil {
return nil, err
}
if len(m.openQueue) == 0 && m.nextStream <= m.maxStream {
return m.openStream(), nil
}
waitChan := make(chan struct{}, 1)
m.openQueue = append(m.openQueue, waitChan)
queuePos := m.highestInQueue
m.highestInQueue++
if len(m.openQueue) == 0 {
m.lowestInQueue = queuePos
}
m.openQueue[queuePos] = waitChan
m.maybeSendBlockedFrame()
for {
m.mutex.Unlock()
<-waitChan
select {
case <-ctx.Done():
m.mutex.Lock()
delete(m.openQueue, queuePos)
return nil, ctx.Err()
case <-waitChan:
}
m.mutex.Lock()
if m.closeErr != nil {
@ -86,7 +105,7 @@ func (m *outgoingUniStreamsMap) OpenStreamSync() (sendStreamI, error) {
continue
}
str := m.openStream()
m.openQueue = m.openQueue[1:]
delete(m.openQueue, queuePos)
m.unblockOpenSync()
return str, nil
}
@ -159,9 +178,15 @@ func (m *outgoingUniStreamsMap) unblockOpenSync() {
if len(m.openQueue) == 0 {
return
}
select {
case m.openQueue[0] <- struct{}{}:
default:
for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ {
c, ok := m.openQueue[qp]
if !ok { // entry was deleted because the context was canceled
continue
}
close(c)
m.openQueue[qp] = nil
m.lowestInQueue = qp + 1
return
}
}
@ -172,7 +197,9 @@ func (m *outgoingUniStreamsMap) CloseWithError(err error) {
str.closeForShutdown(err)
}
for _, c := range m.openQueue {
close(c)
if c != nil {
close(c)
}
}
m.mutex.Unlock()
}