use the stream helper function in the streamsMap

This commit is contained in:
Marten Seemann 2018-10-31 10:17:59 +07:00
parent 5768b492d7
commit 44243b4f52

View file

@ -9,15 +9,6 @@ import (
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
) )
type streamType int
const (
streamTypeOutgoingBidi streamType = iota
streamTypeIncomingBidi
streamTypeOutgoingUni
streamTypeIncomingUni
)
type streamsMap struct { type streamsMap struct {
perspective protocol.Perspective perspective protocol.Perspective
@ -93,33 +84,6 @@ func newStreamsMap(
return m return m
} }
func (m *streamsMap) getStreamType(id protocol.StreamID) streamType {
if m.perspective == protocol.PerspectiveServer {
switch id % 4 {
case 0:
return streamTypeIncomingBidi
case 1:
return streamTypeOutgoingBidi
case 2:
return streamTypeIncomingUni
case 3:
return streamTypeOutgoingUni
}
} else {
switch id % 4 {
case 0:
return streamTypeOutgoingBidi
case 1:
return streamTypeIncomingBidi
case 2:
return streamTypeOutgoingUni
case 3:
return streamTypeIncomingUni
}
}
panic("")
}
func (m *streamsMap) OpenStream() (Stream, error) { func (m *streamsMap) OpenStream() (Stream, error) {
return m.outgoingBidiStreams.OpenStream() return m.outgoingBidiStreams.OpenStream()
} }
@ -145,64 +109,67 @@ func (m *streamsMap) AcceptUniStream() (ReceiveStream, error) {
} }
func (m *streamsMap) DeleteStream(id protocol.StreamID) error { func (m *streamsMap) DeleteStream(id protocol.StreamID) error {
switch m.getStreamType(id) { switch id.Type() {
case streamTypeIncomingBidi: case protocol.StreamTypeUni:
return m.incomingBidiStreams.DeleteStream(id) if id.InitiatedBy() == m.perspective {
case streamTypeOutgoingBidi:
return m.outgoingBidiStreams.DeleteStream(id)
case streamTypeIncomingUni:
return m.incomingUniStreams.DeleteStream(id)
case streamTypeOutgoingUni:
return m.outgoingUniStreams.DeleteStream(id) return m.outgoingUniStreams.DeleteStream(id)
default:
panic("invalid stream type")
} }
return m.incomingUniStreams.DeleteStream(id)
case protocol.StreamTypeBidi:
if id.InitiatedBy() == m.perspective {
return m.outgoingBidiStreams.DeleteStream(id)
}
return m.incomingBidiStreams.DeleteStream(id)
}
panic("")
} }
func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) { func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
switch m.getStreamType(id) { switch id.Type() {
case streamTypeOutgoingBidi: case protocol.StreamTypeUni:
return m.outgoingBidiStreams.GetStream(id) if id.InitiatedBy() == m.perspective {
case streamTypeIncomingBidi:
return m.incomingBidiStreams.GetOrOpenStream(id)
case streamTypeIncomingUni:
return m.incomingUniStreams.GetOrOpenStream(id)
case streamTypeOutgoingUni:
// an outgoing unidirectional stream is a send stream, not a receive stream // an outgoing unidirectional stream is a send stream, not a receive stream
return nil, fmt.Errorf("peer attempted to open receive stream %d", id) return nil, fmt.Errorf("peer attempted to open receive stream %d", id)
default:
panic("invalid stream type")
} }
return m.incomingUniStreams.GetOrOpenStream(id)
case protocol.StreamTypeBidi:
if id.InitiatedBy() == m.perspective {
return m.outgoingBidiStreams.GetStream(id)
}
return m.incomingBidiStreams.GetOrOpenStream(id)
}
panic("")
} }
func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) { func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
switch m.getStreamType(id) { switch id.Type() {
case streamTypeOutgoingBidi: case protocol.StreamTypeUni:
return m.outgoingBidiStreams.GetStream(id) if id.InitiatedBy() == m.perspective {
case streamTypeIncomingBidi:
return m.incomingBidiStreams.GetOrOpenStream(id)
case streamTypeOutgoingUni:
return m.outgoingUniStreams.GetStream(id) return m.outgoingUniStreams.GetStream(id)
case streamTypeIncomingUni: }
// an incoming unidirectional stream is a receive stream, not a send stream // an incoming unidirectional stream is a receive stream, not a send stream
return nil, fmt.Errorf("peer attempted to open send stream %d", id) return nil, fmt.Errorf("peer attempted to open send stream %d", id)
default: case protocol.StreamTypeBidi:
panic("invalid stream type") if id.InitiatedBy() == m.perspective {
return m.outgoingBidiStreams.GetStream(id)
} }
return m.incomingBidiStreams.GetOrOpenStream(id)
}
panic("")
} }
func (m *streamsMap) HandleMaxStreamIDFrame(f *wire.MaxStreamIDFrame) error { func (m *streamsMap) HandleMaxStreamIDFrame(f *wire.MaxStreamIDFrame) error {
id := f.StreamID id := f.StreamID
switch m.getStreamType(id) { if id.InitiatedBy() != m.perspective {
case streamTypeOutgoingBidi:
m.outgoingBidiStreams.SetMaxStream(id)
return nil
case streamTypeOutgoingUni:
m.outgoingUniStreams.SetMaxStream(id)
return nil
default:
return fmt.Errorf("received MAX_STREAM_DATA frame for incoming stream %d", id) return fmt.Errorf("received MAX_STREAM_DATA frame for incoming stream %d", id)
} }
switch id.Type() {
case protocol.StreamTypeUni:
m.outgoingUniStreams.SetMaxStream(id)
case protocol.StreamTypeBidi:
m.outgoingBidiStreams.SetMaxStream(id)
}
return nil
} }
func (m *streamsMap) UpdateLimits(p *handshake.TransportParameters) { func (m *streamsMap) UpdateLimits(p *handshake.TransportParameters) {