mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
cut coalesed packets in the session
This commit is contained in:
parent
df34e4496e
commit
02e851bd11
12 changed files with 442 additions and 577 deletions
12
client.go
12
client.go
|
@ -288,7 +288,7 @@ func (c *client) establishSecureConnection(ctx context.Context) error {
|
|||
|
||||
func (c *client) handlePacket(p *receivedPacket) {
|
||||
if wire.IsVersionNegotiationPacket(p.data) {
|
||||
go c.handleVersionNegotiationPacket(p.hdr)
|
||||
go c.handleVersionNegotiationPacket(p)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -301,10 +301,16 @@ func (c *client) handlePacket(p *receivedPacket) {
|
|||
c.session.handlePacket(p)
|
||||
}
|
||||
|
||||
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) {
|
||||
func (c *client) handleVersionNegotiationPacket(p *receivedPacket) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
hdr, _, _, err := wire.ParsePacket(p.data, 0)
|
||||
if err != nil {
|
||||
c.logger.Debugf("Error parsing Version Negotiation packet: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
// ignore delayed / duplicated version negotiation packets
|
||||
if c.receivedVersionNegotiationPacket || c.versionNegotiated.Get() {
|
||||
c.logger.Debugf("Received a delayed Version Negotiation packet.")
|
||||
|
@ -403,6 +409,6 @@ func (c *client) GetVersion() protocol.VersionNumber {
|
|||
return v
|
||||
}
|
||||
|
||||
func (c *client) GetPerspective() protocol.Perspective {
|
||||
func (c *client) getPerspective() protocol.Perspective {
|
||||
return protocol.PerspectiveClient
|
||||
}
|
||||
|
|
|
@ -58,12 +58,9 @@ var _ = Describe("Client", func() {
|
|||
composeVersionNegotiationPacket := func(connID protocol.ConnectionID, versions []protocol.VersionNumber) *receivedPacket {
|
||||
data, err := wire.ComposeVersionNegotiation(connID, nil, versions)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
hdr, _, _, err := wire.ParsePacket(data, 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(wire.IsVersionNegotiationPacket(data)).To(BeTrue())
|
||||
return &receivedPacket{
|
||||
rcvTime: time.Now(),
|
||||
hdr: hdr,
|
||||
data: data,
|
||||
}
|
||||
}
|
||||
|
@ -543,19 +540,22 @@ var _ = Describe("Client", func() {
|
|||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
|
||||
It("recognizes that a non version negotiation packet means that the server accepted the suggested version", func() {
|
||||
It("recognizes that a non Version Negotiation packet means that the server accepted the suggested version", func() {
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().handlePacket(gomock.Any())
|
||||
cl.session = sess
|
||||
cl.config = &Config{}
|
||||
cl.handlePacket(&receivedPacket{
|
||||
hdr: &wire.Header{
|
||||
buf := &bytes.Buffer{}
|
||||
Expect((&wire.ExtendedHeader{
|
||||
Header: wire.Header{
|
||||
DestConnectionID: connID,
|
||||
SrcConnectionID: connID,
|
||||
Version: cl.version,
|
||||
},
|
||||
})
|
||||
Eventually(cl.versionNegotiated.Get()).Should(BeTrue())
|
||||
PacketNumberLen: protocol.PacketNumberLen3,
|
||||
}).Write(buf, protocol.VersionTLS)).To(Succeed())
|
||||
cl.handlePacket(&receivedPacket{data: buf.Bytes()})
|
||||
Eventually(cl.versionNegotiated.Get).Should(BeTrue())
|
||||
})
|
||||
|
||||
It("errors if no matching version is found", func() {
|
||||
|
|
|
@ -46,18 +46,6 @@ func (mr *MockPacketHandlerMockRecorder) Close() *gomock.Call {
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketHandler)(nil).Close))
|
||||
}
|
||||
|
||||
// GetPerspective mocks base method
|
||||
func (m *MockPacketHandler) GetPerspective() protocol.Perspective {
|
||||
ret := m.ctrl.Call(m, "GetPerspective")
|
||||
ret0, _ := ret[0].(protocol.Perspective)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetPerspective indicates an expected call of GetPerspective
|
||||
func (mr *MockPacketHandlerMockRecorder) GetPerspective() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPerspective", reflect.TypeOf((*MockPacketHandler)(nil).GetPerspective))
|
||||
}
|
||||
|
||||
// destroy mocks base method
|
||||
func (m *MockPacketHandler) destroy(arg0 error) {
|
||||
m.ctrl.Call(m, "destroy", arg0)
|
||||
|
@ -68,6 +56,18 @@ func (mr *MockPacketHandlerMockRecorder) destroy(arg0 interface{}) *gomock.Call
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "destroy", reflect.TypeOf((*MockPacketHandler)(nil).destroy), arg0)
|
||||
}
|
||||
|
||||
// getPerspective mocks base method
|
||||
func (m *MockPacketHandler) getPerspective() protocol.Perspective {
|
||||
ret := m.ctrl.Call(m, "getPerspective")
|
||||
ret0, _ := ret[0].(protocol.Perspective)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// getPerspective indicates an expected call of getPerspective
|
||||
func (mr *MockPacketHandlerMockRecorder) getPerspective() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getPerspective", reflect.TypeOf((*MockPacketHandler)(nil).getPerspective))
|
||||
}
|
||||
|
||||
// handlePacket mocks base method
|
||||
func (m *MockPacketHandler) handlePacket(arg0 *receivedPacket) {
|
||||
m.ctrl.Call(m, "handlePacket", arg0)
|
||||
|
|
|
@ -231,6 +231,18 @@ func (mr *MockQuicSessionMockRecorder) destroy(arg0 interface{}) *gomock.Call {
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "destroy", reflect.TypeOf((*MockQuicSession)(nil).destroy), arg0)
|
||||
}
|
||||
|
||||
// getPerspective mocks base method
|
||||
func (m *MockQuicSession) getPerspective() protocol.Perspective {
|
||||
ret := m.ctrl.Call(m, "getPerspective")
|
||||
ret0, _ := ret[0].(protocol.Perspective)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// getPerspective indicates an expected call of getPerspective
|
||||
func (mr *MockQuicSessionMockRecorder) getPerspective() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getPerspective", reflect.TypeOf((*MockQuicSession)(nil).getPerspective))
|
||||
}
|
||||
|
||||
// handlePacket mocks base method
|
||||
func (m *MockQuicSession) handlePacket(arg0 *receivedPacket) {
|
||||
m.ctrl.Call(m, "handlePacket", arg0)
|
||||
|
|
|
@ -2,7 +2,6 @@ package quic
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -105,7 +104,7 @@ func (h *packetHandlerMap) CloseServer() {
|
|||
var wg sync.WaitGroup
|
||||
for id, handlerEntry := range h.handlers {
|
||||
handler := handlerEntry.handler
|
||||
if handler.GetPerspective() == protocol.PerspectiveServer {
|
||||
if handler.getPerspective() == protocol.PerspectiveServer {
|
||||
wg.Add(1)
|
||||
go func(id string, handler packetHandler) {
|
||||
// session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
|
||||
|
@ -174,93 +173,46 @@ func (h *packetHandlerMap) handlePacket(
|
|||
buffer *packetBuffer,
|
||||
data []byte,
|
||||
) {
|
||||
packets, err := h.parsePacket(addr, buffer, data)
|
||||
connID, err := wire.ParseConnectionID(data, h.connIDLen)
|
||||
if err != nil {
|
||||
h.logger.Debugf("error parsing packets from %s: %s", addr, err)
|
||||
// This is just the error from parsing the last packet.
|
||||
// We still need to process the packets that were successfully parsed before.
|
||||
}
|
||||
if len(packets) == 0 {
|
||||
buffer.Release()
|
||||
h.logger.Debugf("error parsing connection ID on packet from %s: %s", addr, err)
|
||||
return
|
||||
}
|
||||
h.handleParsedPackets(packets)
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) parsePacket(
|
||||
addr net.Addr,
|
||||
buffer *packetBuffer,
|
||||
data []byte,
|
||||
) ([]*receivedPacket, error) {
|
||||
rcvTime := time.Now()
|
||||
packets := make([]*receivedPacket, 0, 1)
|
||||
|
||||
var counter int
|
||||
var lastConnID protocol.ConnectionID
|
||||
for len(data) > 0 {
|
||||
hdr, packetData, rest, err := wire.ParsePacket(data, h.connIDLen)
|
||||
if err != nil {
|
||||
return packets, fmt.Errorf("error parsing packet: %s", err)
|
||||
}
|
||||
|
||||
if counter > 0 && !hdr.DestConnectionID.Equal(lastConnID) {
|
||||
return packets, fmt.Errorf("coalesced packet has different destination connection ID: %s, expected %s", hdr.DestConnectionID, lastConnID)
|
||||
}
|
||||
lastConnID = hdr.DestConnectionID
|
||||
|
||||
if counter > 0 {
|
||||
buffer.Split()
|
||||
}
|
||||
counter++
|
||||
packets = append(packets, &receivedPacket{
|
||||
remoteAddr: addr,
|
||||
hdr: hdr,
|
||||
rcvTime: rcvTime,
|
||||
data: packetData,
|
||||
buffer: buffer,
|
||||
})
|
||||
|
||||
// only log if this actually a coalesced packet
|
||||
if h.logger.Debug() && (counter > 1 || len(rest) > 0) {
|
||||
h.logger.Debugf("Parsed a coalesced packet. Part %d: %d bytes. Remaining: %d bytes.", counter, len(packets[counter-1].data), len(rest))
|
||||
}
|
||||
|
||||
data = rest
|
||||
}
|
||||
return packets, nil
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) handleParsedPackets(packets []*receivedPacket) {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
|
||||
// coalesced packets all have the same destination connection ID
|
||||
handlerEntry, handlerFound := h.handlers[string(packets[0].hdr.DestConnectionID)]
|
||||
handlerEntry, handlerFound := h.handlers[string(connID)]
|
||||
|
||||
for _, p := range packets {
|
||||
if handlerFound { // existing session
|
||||
handlerEntry.handler.handlePacket(p)
|
||||
continue
|
||||
}
|
||||
// No session found.
|
||||
// This might be a stateless reset.
|
||||
if !p.hdr.IsLongHeader {
|
||||
if len(p.data) >= protocol.MinStatelessResetSize {
|
||||
var token [16]byte
|
||||
copy(token[:], p.data[len(p.data)-16:])
|
||||
if sess, ok := h.resetTokens[token]; ok {
|
||||
sess.destroy(errors.New("received a stateless reset"))
|
||||
continue
|
||||
}
|
||||
}
|
||||
// TODO(#943): send a stateless reset
|
||||
h.logger.Debugf("received a short header packet with an unexpected connection ID %s", p.hdr.DestConnectionID)
|
||||
break // a short header packet is always the last in a coalesced packet
|
||||
}
|
||||
if h.server == nil { // no server set
|
||||
h.logger.Debugf("received a packet with an unexpected connection ID %s", p.hdr.DestConnectionID)
|
||||
continue
|
||||
}
|
||||
h.server.handlePacket(p)
|
||||
p := &receivedPacket{
|
||||
remoteAddr: addr,
|
||||
rcvTime: rcvTime,
|
||||
buffer: buffer,
|
||||
data: data,
|
||||
}
|
||||
if handlerFound { // existing session
|
||||
handlerEntry.handler.handlePacket(p)
|
||||
return
|
||||
}
|
||||
// No session found.
|
||||
// This might be a stateless reset.
|
||||
if data[0]&0x80 == 0 { // stateless resets are always short header packets
|
||||
if len(p.data) >= protocol.MinStatelessResetSize {
|
||||
var token [16]byte
|
||||
copy(token[:], p.data[len(p.data)-16:])
|
||||
if sess, ok := h.resetTokens[token]; ok {
|
||||
sess.destroy(errors.New("received a stateless reset"))
|
||||
return
|
||||
}
|
||||
}
|
||||
// TODO(#943): send a stateless reset
|
||||
h.logger.Debugf("received a short header packet with an unexpected connection ID %s", connID)
|
||||
return
|
||||
}
|
||||
if h.server == nil { // no server set
|
||||
h.logger.Debugf("received a packet with an unexpected connection ID %s", connID)
|
||||
return
|
||||
}
|
||||
h.server.handlePacket(p)
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@ package quic
|
|||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
|
@ -88,11 +87,15 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
handledPacket1 := make(chan struct{})
|
||||
handledPacket2 := make(chan struct{})
|
||||
packetHandler1.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
|
||||
Expect(p.hdr.DestConnectionID).To(Equal(connID1))
|
||||
connID, err := wire.ParseConnectionID(p.data, 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(connID).To(Equal(connID1))
|
||||
close(handledPacket1)
|
||||
})
|
||||
packetHandler2.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
|
||||
Expect(p.hdr.DestConnectionID).To(Equal(connID2))
|
||||
connID, err := wire.ParseConnectionID(p.data, 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(connID).To(Equal(connID2))
|
||||
close(handledPacket2)
|
||||
})
|
||||
handler.Add(connID1, packetHandler1)
|
||||
|
@ -105,12 +108,10 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
})
|
||||
|
||||
It("drops unparseable packets", func() {
|
||||
_, err := handler.parsePacket(nil, nil, []byte{0, 1, 2, 3})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("error parsing packet:"))
|
||||
handler.handlePacket(nil, nil, []byte{0, 1, 2, 3})
|
||||
})
|
||||
|
||||
It("deletes removed session immediately", func() {
|
||||
It("deletes removed sessions immediately", func() {
|
||||
handler.deleteRetiredSessionsAfter = time.Hour
|
||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
handler.Add(connID, NewMockPacketHandler(mockCtrl))
|
||||
|
@ -159,64 +160,6 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
conn.Close()
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
Context("coalesced packets", func() {
|
||||
It("cuts packets to the right length", func() {
|
||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
data := append(getPacketWithLength(connID, 456), make([]byte, 1000)...)
|
||||
packetHandler := NewMockPacketHandler(mockCtrl)
|
||||
packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
|
||||
Expect(p.data).To(HaveLen(456 + int(p.hdr.ParsedLen())))
|
||||
})
|
||||
handler.Add(connID, packetHandler)
|
||||
handler.handlePacket(nil, nil, data)
|
||||
})
|
||||
|
||||
It("handles coalesced packets", func() {
|
||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
packetHandler := NewMockPacketHandler(mockCtrl)
|
||||
handledPackets := make(chan *receivedPacket, 3)
|
||||
packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
|
||||
handledPackets <- p
|
||||
}).Times(3)
|
||||
handler.Add(connID, packetHandler)
|
||||
|
||||
buffer := getPacketBuffer()
|
||||
packet := buffer.Slice[:0]
|
||||
packet = append(packet, append(getPacketWithLength(connID, 10), make([]byte, 10-2 /* packet number len */)...)...)
|
||||
packet = append(packet, append(getPacketWithLength(connID, 20), make([]byte, 20-2 /* packet number len */)...)...)
|
||||
packet = append(packet, append(getPacketWithLength(connID, 30), make([]byte, 30-2 /* packet number len */)...)...)
|
||||
conn.dataToRead <- packet
|
||||
|
||||
now := time.Now()
|
||||
for i := 1; i <= 3; i++ {
|
||||
var p *receivedPacket
|
||||
Eventually(handledPackets).Should(Receive(&p))
|
||||
Expect(p.hdr.DestConnectionID).To(Equal(connID))
|
||||
Expect(p.hdr.Length).To(BeEquivalentTo(10 * i))
|
||||
Expect(p.data).To(HaveLen(int(p.hdr.ParsedLen() + p.hdr.Length)))
|
||||
Expect(p.rcvTime).To(BeTemporally("~", now, scaleDuration(20*time.Millisecond)))
|
||||
Expect(p.buffer.refCount).To(Equal(3))
|
||||
}
|
||||
})
|
||||
|
||||
It("ignores coalesced packet parts if the connection IDs don't match", func() {
|
||||
connID1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
|
||||
|
||||
buffer := getPacketBuffer()
|
||||
packet := buffer.Slice[:0]
|
||||
// var packet []byte
|
||||
packet = append(packet, getPacket(connID1)...)
|
||||
packet = append(packet, getPacket(connID2)...)
|
||||
|
||||
packets, err := handler.parsePacket(&net.UDPAddr{}, buffer, packet)
|
||||
Expect(err).To(MatchError("coalesced packet has different destination connection ID: 0x0807060504030201, expected 0x0102030405060708"))
|
||||
Expect(packets).To(HaveLen(1))
|
||||
Expect(packets[0].hdr.DestConnectionID).To(Equal(connID1))
|
||||
Expect(packets[0].buffer.refCount).To(Equal(1))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("stateless reset handling", func() {
|
||||
|
@ -228,7 +171,9 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
// first send a normal packet
|
||||
handledPacket := make(chan struct{})
|
||||
packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
|
||||
Expect(p.hdr.DestConnectionID).To(Equal(connID))
|
||||
cid, err := wire.ParseConnectionID(p.data, 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cid).To(Equal(connID))
|
||||
close(handledPacket)
|
||||
})
|
||||
conn.dataToRead <- getPacket(connID)
|
||||
|
@ -250,24 +195,6 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
Eventually(destroyed).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("detects a stateless that is coalesced with another packet", func() {
|
||||
packetHandler := NewMockPacketHandler(mockCtrl)
|
||||
connID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}
|
||||
token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
|
||||
handler.AddWithResetToken(connID, packetHandler, token)
|
||||
fakeConnID := protocol.ConnectionID{1, 2, 3, 4, 5}
|
||||
packet := getPacket(fakeConnID)
|
||||
reset := append([]byte{0x40} /* short header packet */, fakeConnID...)
|
||||
reset = append(reset, make([]byte, 50)...) // add some "random" data
|
||||
reset = append(reset, token[:]...)
|
||||
destroyed := make(chan struct{})
|
||||
packetHandler.EXPECT().destroy(errors.New("received a stateless reset")).Do(func(error) {
|
||||
close(destroyed)
|
||||
})
|
||||
conn.dataToRead <- append(packet, reset...)
|
||||
Eventually(destroyed).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("deletes reset tokens when the session is retired", func() {
|
||||
handler.deleteRetiredSessionsAfter = scaleDuration(10 * time.Millisecond)
|
||||
connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0x42}
|
||||
|
@ -291,7 +218,9 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
p := getPacket(connID)
|
||||
server := NewMockUnknownPacketHandler(mockCtrl)
|
||||
server.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
|
||||
Expect(p.hdr.DestConnectionID).To(Equal(connID))
|
||||
cid, err := wire.ParseConnectionID(p.data, 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cid).To(Equal(connID))
|
||||
})
|
||||
handler.SetServer(server)
|
||||
handler.handlePacket(nil, nil, p)
|
||||
|
@ -299,9 +228,9 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
|
||||
It("closes all server sessions", func() {
|
||||
clientSess := NewMockPacketHandler(mockCtrl)
|
||||
clientSess.EXPECT().GetPerspective().Return(protocol.PerspectiveClient)
|
||||
clientSess.EXPECT().getPerspective().Return(protocol.PerspectiveClient)
|
||||
serverSess := NewMockPacketHandler(mockCtrl)
|
||||
serverSess.EXPECT().GetPerspective().Return(protocol.PerspectiveServer)
|
||||
serverSess.EXPECT().getPerspective().Return(protocol.PerspectiveServer)
|
||||
serverSess.EXPECT().Close()
|
||||
|
||||
handler.Add(protocol.ConnectionID{1, 1, 1, 1}, clientSess)
|
||||
|
|
86
server.go
86
server.go
|
@ -23,7 +23,7 @@ type packetHandler interface {
|
|||
handlePacket(*receivedPacket)
|
||||
io.Closer
|
||||
destroy(error)
|
||||
GetPerspective() protocol.Perspective
|
||||
getPerspective() protocol.Perspective
|
||||
}
|
||||
|
||||
type unknownPacketHandler interface {
|
||||
|
@ -44,6 +44,7 @@ type quicSession interface {
|
|||
Session
|
||||
handlePacket(*receivedPacket)
|
||||
GetVersion() protocol.VersionNumber
|
||||
getPerspective() protocol.Perspective
|
||||
run() error
|
||||
destroy(error)
|
||||
closeForRecreating() protocol.PacketNumber
|
||||
|
@ -324,53 +325,60 @@ func (s *server) Addr() net.Addr {
|
|||
}
|
||||
|
||||
func (s *server) handlePacket(p *receivedPacket) {
|
||||
hdr := p.hdr
|
||||
|
||||
// send a Version Negotiation Packet if the client is speaking a different protocol version
|
||||
if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
|
||||
go s.sendVersionNegotiationPacket(p)
|
||||
return
|
||||
}
|
||||
if hdr.Type == protocol.PacketTypeInitial {
|
||||
go s.handleInitial(p)
|
||||
return
|
||||
}
|
||||
|
||||
defer p.buffer.Release()
|
||||
// Drop long header packets.
|
||||
// There's litte point in sending a Stateless Reset, since the client
|
||||
// might not have received the token yet.
|
||||
if hdr.IsLongHeader {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
if shouldReleaseBuffer := s.handlePacketImpl(p); !shouldReleaseBuffer {
|
||||
p.buffer.Release()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *server) handleInitial(p *receivedPacket) {
|
||||
s.logger.Debugf("<- Received Initial packet.")
|
||||
sess, connID, err := s.handleInitialImpl(p)
|
||||
func (s *server) handlePacketImpl(p *receivedPacket) bool /* was the packet passed on to a session */ {
|
||||
if len(p.data) < protocol.MinInitialPacketSize {
|
||||
s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", len(p.data))
|
||||
return false
|
||||
}
|
||||
// If we're creating a new session, the packet will be passed to the session.
|
||||
// The header will then be parsed again.
|
||||
hdr, _, _, err := wire.ParsePacket(p.data, s.config.ConnectionIDLength)
|
||||
if err != nil {
|
||||
s.logger.Debugf("Error parsing packet: %s", err)
|
||||
return false
|
||||
}
|
||||
if !hdr.IsLongHeader {
|
||||
// TODO: send a stateless reset
|
||||
return false
|
||||
}
|
||||
// send a Version Negotiation Packet if the client is speaking a different protocol version
|
||||
if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
|
||||
s.sendVersionNegotiationPacket(p, hdr)
|
||||
return false
|
||||
}
|
||||
if hdr.IsLongHeader && hdr.Type != protocol.PacketTypeInitial {
|
||||
// Drop long header packets.
|
||||
// There's litte point in sending a Stateless Reset, since the client
|
||||
// might not have received the token yet.
|
||||
return false
|
||||
}
|
||||
|
||||
s.logger.Debugf("<- Received Initial packet.")
|
||||
|
||||
sess, connID, err := s.handleInitialImpl(p, hdr)
|
||||
if err != nil {
|
||||
p.buffer.Release()
|
||||
s.logger.Errorf("Error occurred handling initial packet: %s", err)
|
||||
return
|
||||
return false
|
||||
}
|
||||
if sess == nil { // a retry was done, or the connection attempt was rejected
|
||||
p.buffer.Release()
|
||||
return
|
||||
return false
|
||||
}
|
||||
// Don't put the packet buffer back if a new session was created.
|
||||
// The session will handle the packet and take of that.
|
||||
serverSession := newServerSession(sess, s.config, s.logger)
|
||||
s.sessionHandler.Add(connID, serverSession)
|
||||
s.sessionHandler.Add(connID, sess)
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *server) handleInitialImpl(p *receivedPacket) (quicSession, protocol.ConnectionID, error) {
|
||||
hdr := p.hdr
|
||||
func (s *server) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (quicSession, protocol.ConnectionID, error) {
|
||||
if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
|
||||
return nil, nil, errors.New("dropping Initial packet with too short connection ID")
|
||||
}
|
||||
if len(p.data) < protocol.MinInitialPacketSize {
|
||||
return nil, nil, errors.New("dropping too small Initial packet")
|
||||
return nil, nil, errors.New("too short connection ID")
|
||||
}
|
||||
|
||||
var cookie *Cookie
|
||||
|
@ -388,7 +396,7 @@ func (s *server) handleInitialImpl(p *receivedPacket) (quicSession, protocol.Con
|
|||
if !s.config.AcceptCookie(p.remoteAddr, cookie) {
|
||||
// Log the Initial packet now.
|
||||
// If no Retry is sent, the packet will be logged by the session.
|
||||
(&wire.ExtendedHeader{Header: *p.hdr}).Log(s.logger)
|
||||
(&wire.ExtendedHeader{Header: *hdr}).Log(s.logger)
|
||||
return nil, nil, s.sendRetry(p.remoteAddr, hdr)
|
||||
}
|
||||
|
||||
|
@ -535,9 +543,7 @@ func (s *server) sendServerBusy(remoteAddr net.Addr, hdr *wire.Header) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *server) sendVersionNegotiationPacket(p *receivedPacket) {
|
||||
defer p.buffer.Release()
|
||||
hdr := p.hdr
|
||||
func (s *server) sendVersionNegotiationPacket(p *receivedPacket, hdr *wire.Header) {
|
||||
s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version)
|
||||
data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions)
|
||||
if err != nil {
|
||||
|
|
|
@ -1,59 +0,0 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
type serverSession struct {
|
||||
quicSession
|
||||
|
||||
config *Config
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
var _ packetHandler = &serverSession{}
|
||||
|
||||
func newServerSession(sess quicSession, config *Config, logger utils.Logger) packetHandler {
|
||||
return &serverSession{
|
||||
quicSession: sess,
|
||||
config: config,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serverSession) handlePacket(p *receivedPacket) {
|
||||
if err := s.handlePacketImpl(p); err != nil {
|
||||
s.logger.Debugf("error handling packet from %s: %s", p.remoteAddr, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serverSession) handlePacketImpl(p *receivedPacket) error {
|
||||
hdr := p.hdr
|
||||
|
||||
// Probably an old packet that was sent by the client before the version was negotiated.
|
||||
// It is safe to drop it.
|
||||
if hdr.IsLongHeader && hdr.Version != s.quicSession.GetVersion() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if hdr.IsLongHeader {
|
||||
switch hdr.Type {
|
||||
case protocol.PacketTypeInitial, protocol.PacketTypeHandshake:
|
||||
// nothing to do here. Packet will be passed to the session.
|
||||
default:
|
||||
// Note that this also drops 0-RTT packets.
|
||||
return fmt.Errorf("Received unsupported packet type: %s", hdr.Type)
|
||||
}
|
||||
}
|
||||
|
||||
s.quicSession.handlePacket(p)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serverSession) GetPerspective() protocol.Perspective {
|
||||
return protocol.PerspectiveServer
|
||||
}
|
|
@ -1,78 +0,0 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Server Session", func() {
|
||||
var (
|
||||
qsess *MockQuicSession
|
||||
sess *serverSession
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
qsess = NewMockQuicSession(mockCtrl)
|
||||
sess = newServerSession(qsess, &Config{}, utils.DefaultLogger).(*serverSession)
|
||||
})
|
||||
|
||||
It("handles packets", func() {
|
||||
p := &receivedPacket{
|
||||
hdr: &wire.Header{
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5},
|
||||
},
|
||||
}
|
||||
qsess.EXPECT().handlePacket(p)
|
||||
sess.handlePacket(p)
|
||||
})
|
||||
|
||||
It("ignores delayed packets with mismatching versions", func() {
|
||||
qsess.EXPECT().GetVersion().Return(protocol.VersionNumber(100))
|
||||
// don't EXPECT any calls to handlePacket()
|
||||
p := &receivedPacket{
|
||||
hdr: &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Version: protocol.VersionNumber(123),
|
||||
DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef},
|
||||
},
|
||||
}
|
||||
err := sess.handlePacketImpl(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("ignores packets with the wrong Long Header type", func() {
|
||||
qsess.EXPECT().GetVersion().Return(protocol.VersionNumber(100))
|
||||
p := &receivedPacket{
|
||||
hdr: &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeRetry,
|
||||
Version: protocol.VersionNumber(100),
|
||||
DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef},
|
||||
},
|
||||
}
|
||||
err := sess.handlePacketImpl(p)
|
||||
Expect(err).To(MatchError("Received unsupported packet type: Retry"))
|
||||
})
|
||||
|
||||
It("passes on Handshake packets", func() {
|
||||
p := &receivedPacket{
|
||||
hdr: &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeHandshake,
|
||||
Version: protocol.VersionNumber(100),
|
||||
DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef},
|
||||
},
|
||||
}
|
||||
qsess.EXPECT().GetVersion().Return(protocol.VersionNumber(100))
|
||||
qsess.EXPECT().handlePacket(p)
|
||||
Expect(sess.handlePacketImpl(p)).To(Succeed())
|
||||
})
|
||||
|
||||
It("has the right perspective", func() {
|
||||
Expect(sess.GetPerspective()).To(Equal(protocol.PerspectiveServer))
|
||||
})
|
||||
})
|
162
server_test.go
162
server_test.go
|
@ -26,6 +26,18 @@ var _ = Describe("Server", func() {
|
|||
tlsConf *tls.Config
|
||||
)
|
||||
|
||||
getPacket := func(hdr *wire.Header, data []byte) *receivedPacket {
|
||||
buf := &bytes.Buffer{}
|
||||
Expect((&wire.ExtendedHeader{
|
||||
Header: *hdr,
|
||||
PacketNumberLen: protocol.PacketNumberLen3,
|
||||
}).Write(buf, protocol.VersionTLS)).To(Succeed())
|
||||
return &receivedPacket{
|
||||
data: append(buf.Bytes(), data...),
|
||||
buffer: getPacketBuffer(),
|
||||
}
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
conn = newMockPacketConn()
|
||||
conn.addr = &net.UDPAddr{}
|
||||
|
@ -124,53 +136,45 @@ var _ = Describe("Server", func() {
|
|||
}
|
||||
|
||||
It("drops Initial packets with a too short connection ID", func() {
|
||||
serv.handlePacket(insertPacketBuffer(&receivedPacket{
|
||||
hdr: &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4},
|
||||
Version: serv.config.Versions[0],
|
||||
},
|
||||
}))
|
||||
serv.handlePacket(getPacket(&wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4},
|
||||
Version: serv.config.Versions[0],
|
||||
}, nil))
|
||||
Consistently(conn.dataWritten).ShouldNot(Receive())
|
||||
})
|
||||
|
||||
It("drops too small Initial", func() {
|
||||
serv.handlePacket(insertPacketBuffer(&receivedPacket{
|
||||
hdr: &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
Version: serv.config.Versions[0],
|
||||
},
|
||||
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize-100),
|
||||
}))
|
||||
serv.handlePacket(getPacket(&wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
Version: serv.config.Versions[0],
|
||||
}, make([]byte, protocol.MinInitialPacketSize-100),
|
||||
))
|
||||
Consistently(conn.dataWritten).ShouldNot(Receive())
|
||||
})
|
||||
|
||||
It("drops packets with a too short connection ID", func() {
|
||||
serv.handlePacket(insertPacketBuffer(&receivedPacket{
|
||||
hdr: &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4},
|
||||
Version: serv.config.Versions[0],
|
||||
},
|
||||
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
|
||||
}))
|
||||
serv.handlePacket(getPacket(&wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4},
|
||||
Version: serv.config.Versions[0],
|
||||
}, make([]byte, protocol.MinInitialPacketSize)))
|
||||
Consistently(conn.dataWritten).ShouldNot(Receive())
|
||||
})
|
||||
|
||||
It("drops non-Initial packets", func() {
|
||||
serv.logger.SetLogLevel(utils.LogLevelDebug)
|
||||
serv.handlePacket(insertPacketBuffer(&receivedPacket{
|
||||
hdr: &wire.Header{
|
||||
serv.handlePacket(getPacket(
|
||||
&wire.Header{
|
||||
Type: protocol.PacketTypeHandshake,
|
||||
Version: serv.config.Versions[0],
|
||||
},
|
||||
data: []byte("invalid"),
|
||||
}))
|
||||
[]byte("invalid"),
|
||||
))
|
||||
})
|
||||
|
||||
It("decodes the cookie from the Token field", func() {
|
||||
|
@ -187,15 +191,14 @@ var _ = Describe("Server", func() {
|
|||
}
|
||||
token, err := serv.cookieGenerator.NewToken(raddr, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serv.handlePacket(insertPacketBuffer(&receivedPacket{
|
||||
remoteAddr: raddr,
|
||||
hdr: &wire.Header{
|
||||
Type: protocol.PacketTypeInitial,
|
||||
Token: token,
|
||||
Version: serv.config.Versions[0],
|
||||
},
|
||||
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
|
||||
}))
|
||||
packet := getPacket(&wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
Token: token,
|
||||
Version: serv.config.Versions[0],
|
||||
}, make([]byte, protocol.MinInitialPacketSize))
|
||||
packet.remoteAddr = raddr
|
||||
serv.handlePacket(packet)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
|
@ -211,31 +214,29 @@ var _ = Describe("Server", func() {
|
|||
close(done)
|
||||
return false
|
||||
}
|
||||
serv.handlePacket(insertPacketBuffer(&receivedPacket{
|
||||
remoteAddr: raddr,
|
||||
hdr: &wire.Header{
|
||||
Type: protocol.PacketTypeInitial,
|
||||
Token: []byte("foobar"),
|
||||
Version: serv.config.Versions[0],
|
||||
},
|
||||
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
|
||||
}))
|
||||
packet := getPacket(&wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
Token: []byte("foobar"),
|
||||
Version: serv.config.Versions[0],
|
||||
}, make([]byte, protocol.MinInitialPacketSize))
|
||||
packet.remoteAddr = raddr
|
||||
serv.handlePacket(packet)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("sends a Version Negotiation Packet for unsupported versions", func() {
|
||||
srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5}
|
||||
destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6}
|
||||
serv.handlePacket(insertPacketBuffer(&receivedPacket{
|
||||
remoteAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337},
|
||||
hdr: &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
SrcConnectionID: srcConnID,
|
||||
DestConnectionID: destConnID,
|
||||
Version: 0x42,
|
||||
},
|
||||
}))
|
||||
packet := getPacket(&wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
SrcConnectionID: srcConnID,
|
||||
DestConnectionID: destConnID,
|
||||
Version: 0x42,
|
||||
}, make([]byte, protocol.MinInitialPacketSize))
|
||||
packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||
serv.handlePacket(packet)
|
||||
var write mockPacketConnWrite
|
||||
Eventually(conn.dataWritten).Should(Receive(&write))
|
||||
Expect(write.to.String()).To(Equal("127.0.0.1:1337"))
|
||||
|
@ -249,16 +250,15 @@ var _ = Describe("Server", func() {
|
|||
It("replies with a Retry packet, if a Cookie is required", func() {
|
||||
serv.config.AcceptCookie = func(_ net.Addr, _ *Cookie) bool { return false }
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||
Version: protocol.VersionTLS,
|
||||
}
|
||||
serv.handleInitial(insertPacketBuffer(&receivedPacket{
|
||||
remoteAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337},
|
||||
hdr: hdr,
|
||||
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
|
||||
}))
|
||||
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||
packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||
serv.handlePacket(packet)
|
||||
var write mockPacketConnWrite
|
||||
Eventually(conn.dataWritten).Should(Receive(&write))
|
||||
Expect(write.to.String()).To(Equal("127.0.0.1:1337"))
|
||||
|
@ -273,15 +273,13 @@ var _ = Describe("Server", func() {
|
|||
It("creates a session, if no Cookie is required", func() {
|
||||
serv.config.AcceptCookie = func(_ net.Addr, _ *Cookie) bool { return true }
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||
Version: protocol.VersionTLS,
|
||||
}
|
||||
p := &receivedPacket{
|
||||
hdr: hdr,
|
||||
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
|
||||
}
|
||||
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||
run := make(chan struct{})
|
||||
serv.newSession = func(
|
||||
_ connection,
|
||||
|
@ -309,7 +307,7 @@ var _ = Describe("Server", func() {
|
|||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
serv.handlePacket(insertPacketBuffer(p))
|
||||
serv.handlePacket(p)
|
||||
// the Handshake packet is written by the session
|
||||
Consistently(conn.dataWritten).ShouldNot(Receive())
|
||||
close(done)
|
||||
|
@ -324,16 +322,14 @@ var _ = Describe("Server", func() {
|
|||
senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
|
||||
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||
Version: protocol.VersionTLS,
|
||||
}
|
||||
p := &receivedPacket{
|
||||
remoteAddr: senderAddr,
|
||||
hdr: hdr,
|
||||
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
|
||||
}
|
||||
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||
p.remoteAddr = senderAddr
|
||||
serv.newSession = func(
|
||||
_ connection,
|
||||
runner sessionRunner,
|
||||
|
@ -360,12 +356,12 @@ var _ = Describe("Server", func() {
|
|||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer wg.Done()
|
||||
serv.handlePacket(insertPacketBuffer(p))
|
||||
serv.handlePacket(p)
|
||||
Consistently(conn.dataWritten).ShouldNot(Receive())
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
serv.handlePacket(insertPacketBuffer(p))
|
||||
serv.handlePacket(p)
|
||||
var reject mockPacketConnWrite
|
||||
Eventually(conn.dataWritten).Should(Receive(&reject))
|
||||
Expect(reject.to).To(Equal(senderAddr))
|
||||
|
@ -381,16 +377,14 @@ var _ = Describe("Server", func() {
|
|||
senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
|
||||
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||
Version: protocol.VersionTLS,
|
||||
}
|
||||
p := &receivedPacket{
|
||||
remoteAddr: senderAddr,
|
||||
hdr: hdr,
|
||||
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
|
||||
}
|
||||
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||
p.remoteAddr = senderAddr
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
sessionCreated := make(chan struct{})
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
|
@ -414,7 +408,7 @@ var _ = Describe("Server", func() {
|
|||
return sess, nil
|
||||
}
|
||||
|
||||
serv.handlePacket(insertPacketBuffer(p))
|
||||
serv.handlePacket(p)
|
||||
Consistently(conn.dataWritten).ShouldNot(Receive())
|
||||
Eventually(sessionCreated).Should(BeClosed())
|
||||
cancel()
|
||||
|
@ -429,7 +423,7 @@ var _ = Describe("Server", func() {
|
|||
Consistently(done).ShouldNot(BeClosed())
|
||||
|
||||
// make the go routine return
|
||||
sess.EXPECT().Close()
|
||||
sess.EXPECT().getPerspective()
|
||||
Expect(serv.Close()).To(Succeed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
|
60
session.go
60
session.go
|
@ -55,7 +55,6 @@ type cryptoStreamHandler interface {
|
|||
|
||||
type receivedPacket struct {
|
||||
remoteAddr net.Addr
|
||||
hdr *wire.Header
|
||||
rcvTime time.Time
|
||||
data []byte
|
||||
|
||||
|
@ -483,7 +482,43 @@ func (s *session) handleHandshakeComplete() {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *session) handlePacketImpl(p *receivedPacket) bool /* was the packet successfully processed */ {
|
||||
func (s *session) handlePacketImpl(p *receivedPacket) bool {
|
||||
var counter uint8
|
||||
var lastConnID protocol.ConnectionID
|
||||
var processed bool
|
||||
for len(p.data) > 0 {
|
||||
hdr, packetData, rest, err := wire.ParsePacket(p.data, s.srcConnID.Len())
|
||||
if err != nil {
|
||||
s.logger.Debugf("error parsing packet: %s", err)
|
||||
break
|
||||
}
|
||||
|
||||
if counter > 0 && !hdr.DestConnectionID.Equal(lastConnID) {
|
||||
s.logger.Debugf("coalesced packet has different destination connection ID: %s, expected %s", hdr.DestConnectionID, lastConnID)
|
||||
break
|
||||
}
|
||||
lastConnID = hdr.DestConnectionID
|
||||
|
||||
if counter > 0 {
|
||||
p.buffer.Split()
|
||||
}
|
||||
counter++
|
||||
|
||||
// only log if this actually a coalesced packet
|
||||
if s.logger.Debug() && (counter > 1 || len(rest) > 0) {
|
||||
s.logger.Debugf("Parsed a coalesced packet. Part %d: %d bytes. Remaining: %d bytes.", counter, len(packetData), len(rest))
|
||||
}
|
||||
p.data = packetData
|
||||
pr := s.handleSinglePacket(p, hdr)
|
||||
if pr {
|
||||
processed = pr
|
||||
}
|
||||
p.data = rest
|
||||
}
|
||||
return processed
|
||||
}
|
||||
|
||||
func (s *session) handleSinglePacket(p *receivedPacket, hdr *wire.Header) bool /* was the packet successfully processed */ {
|
||||
var wasQueued bool
|
||||
|
||||
defer func() {
|
||||
|
@ -493,22 +528,22 @@ func (s *session) handlePacketImpl(p *receivedPacket) bool /* was the packet suc
|
|||
}
|
||||
}()
|
||||
|
||||
if p.hdr.Type == protocol.PacketTypeRetry {
|
||||
return s.handleRetryPacket(p)
|
||||
if hdr.Type == protocol.PacketTypeRetry {
|
||||
return s.handleRetryPacket(p, hdr)
|
||||
}
|
||||
|
||||
// The server can change the source connection ID with the first Handshake packet.
|
||||
// After this, all packets with a different source connection have to be ignored.
|
||||
if s.receivedFirstPacket && p.hdr.IsLongHeader && !p.hdr.SrcConnectionID.Equal(s.destConnID) {
|
||||
s.logger.Debugf("Dropping packet with unexpected source connection ID: %s (expected %s)", p.hdr.SrcConnectionID, s.destConnID)
|
||||
if s.receivedFirstPacket && hdr.IsLongHeader && !hdr.SrcConnectionID.Equal(s.destConnID) {
|
||||
s.logger.Debugf("Dropping packet with unexpected source connection ID: %s (expected %s)", hdr.SrcConnectionID, s.destConnID)
|
||||
return false
|
||||
}
|
||||
// drop 0-RTT packets
|
||||
if p.hdr.Type == protocol.PacketType0RTT {
|
||||
if hdr.Type == protocol.PacketType0RTT {
|
||||
return false
|
||||
}
|
||||
|
||||
packet, err := s.unpacker.Unpack(p.hdr, p.data)
|
||||
packet, err := s.unpacker.Unpack(hdr, p.data)
|
||||
if err != nil {
|
||||
if err == handshake.ErrOpenerNotYetAvailable {
|
||||
// Sealer for this encryption level not yet available.
|
||||
|
@ -524,7 +559,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) bool /* was the packet suc
|
|||
}
|
||||
|
||||
if s.logger.Debug() {
|
||||
s.logger.Debugf("<- Reading packet %#x (%d bytes) for connection %s, %s", packet.packetNumber, len(p.data), p.hdr.DestConnectionID, packet.encryptionLevel)
|
||||
s.logger.Debugf("<- Reading packet %#x (%d bytes) for connection %s, %s", packet.packetNumber, len(p.data), hdr.DestConnectionID, packet.encryptionLevel)
|
||||
packet.hdr.Log(s.logger)
|
||||
}
|
||||
|
||||
|
@ -535,7 +570,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) bool /* was the packet suc
|
|||
return true
|
||||
}
|
||||
|
||||
func (s *session) handleRetryPacket(p *receivedPacket) bool /* was this a valid Retry */ {
|
||||
func (s *session) handleRetryPacket(p *receivedPacket, hdr *wire.Header) bool /* was this a valid Retry */ {
|
||||
if s.perspective == protocol.PerspectiveServer {
|
||||
s.logger.Debugf("Ignoring Retry.")
|
||||
return false
|
||||
|
@ -544,7 +579,6 @@ func (s *session) handleRetryPacket(p *receivedPacket) bool /* was this a valid
|
|||
s.logger.Debugf("Ignoring Retry, since we already received a packet.")
|
||||
return false
|
||||
}
|
||||
hdr := p.hdr
|
||||
(&wire.ExtendedHeader{Header: *hdr}).Log(s.logger)
|
||||
if !hdr.OrigDestConnectionID.Equal(s.destConnID) {
|
||||
s.logger.Debugf("Ignoring spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, s.destConnID)
|
||||
|
@ -1246,6 +1280,10 @@ func (s *session) RemoteAddr() net.Addr {
|
|||
return s.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (s *session) getPerspective() protocol.Perspective {
|
||||
return s.perspective
|
||||
}
|
||||
|
||||
func (s *session) GetVersion() protocol.VersionNumber {
|
||||
return s.version
|
||||
}
|
||||
|
|
293
session_test.go
293
session_test.go
|
@ -3,6 +3,7 @@ package quic
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"net"
|
||||
"runtime/pprof"
|
||||
|
@ -354,20 +355,6 @@ var _ = Describe("Session", func() {
|
|||
Expect(str).To(Equal(mstr))
|
||||
})
|
||||
|
||||
It("drops Retry packets", func() {
|
||||
hdr := wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeRetry,
|
||||
}
|
||||
buf := &bytes.Buffer{}
|
||||
(&wire.ExtendedHeader{Header: hdr}).Write(buf, sess.version)
|
||||
Expect(sess.handlePacketImpl(&receivedPacket{
|
||||
hdr: &hdr,
|
||||
data: buf.Bytes(),
|
||||
buffer: getPacketBuffer(),
|
||||
})).To(BeFalse())
|
||||
})
|
||||
|
||||
Context("closing", func() {
|
||||
var (
|
||||
runErr error
|
||||
|
@ -492,18 +479,26 @@ var _ = Describe("Session", func() {
|
|||
sess.unpacker = unpacker
|
||||
})
|
||||
|
||||
getData := func(extHdr *wire.ExtendedHeader) []byte {
|
||||
getPacket := func(extHdr *wire.ExtendedHeader, data []byte) *receivedPacket {
|
||||
buf := &bytes.Buffer{}
|
||||
Expect(extHdr.Write(buf, sess.version)).To(Succeed())
|
||||
// need to set extHdr.Header, since the wire.Header contains the parsed length
|
||||
hdr, _, _, err := wire.ParsePacket(buf.Bytes(), 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
extHdr.Header = *hdr
|
||||
return buf.Bytes()
|
||||
return &receivedPacket{
|
||||
data: append(buf.Bytes(), data...),
|
||||
buffer: getPacketBuffer(),
|
||||
}
|
||||
}
|
||||
|
||||
It("drops Retry packets", func() {
|
||||
hdr := wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeRetry,
|
||||
}
|
||||
Expect(sess.handlePacketImpl(getPacket(&wire.ExtendedHeader{Header: hdr}, nil))).To(BeFalse())
|
||||
})
|
||||
|
||||
It("informs the ReceivedPacketHandler about non-retransmittable packets", func() {
|
||||
hdr := &wire.ExtendedHeader{
|
||||
Header: wire.Header{DestConnectionID: sess.srcConnID},
|
||||
PacketNumber: 0x37,
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
}
|
||||
|
@ -517,15 +512,14 @@ var _ = Describe("Session", func() {
|
|||
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
|
||||
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.EncryptionInitial, rcvTime, false)
|
||||
sess.receivedPacketHandler = rph
|
||||
Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{
|
||||
rcvTime: rcvTime,
|
||||
hdr: &hdr.Header,
|
||||
data: getData(hdr),
|
||||
}))).To(BeTrue())
|
||||
packet := getPacket(hdr, nil)
|
||||
packet.rcvTime = rcvTime
|
||||
Expect(sess.handlePacketImpl(packet)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("informs the ReceivedPacketHandler about retransmittable packets", func() {
|
||||
hdr := &wire.ExtendedHeader{
|
||||
Header: wire.Header{DestConnectionID: sess.srcConnID},
|
||||
PacketNumber: 0x37,
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
}
|
||||
|
@ -541,11 +535,9 @@ var _ = Describe("Session", func() {
|
|||
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
|
||||
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.EncryptionHandshake, rcvTime, true)
|
||||
sess.receivedPacketHandler = rph
|
||||
Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{
|
||||
rcvTime: rcvTime,
|
||||
hdr: &hdr.Header,
|
||||
data: getData(hdr),
|
||||
}))).To(BeTrue())
|
||||
packet := getPacket(hdr, nil)
|
||||
packet.rcvTime = rcvTime
|
||||
Expect(sess.handlePacketImpl(packet)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("drops a packet when unpacking fails", func() {
|
||||
|
@ -559,10 +551,10 @@ var _ = Describe("Session", func() {
|
|||
sess.run()
|
||||
}()
|
||||
sessionRunner.EXPECT().retireConnectionID(gomock.Any())
|
||||
sess.handlePacket(insertPacketBuffer(&receivedPacket{
|
||||
hdr: &wire.Header{},
|
||||
data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}),
|
||||
}))
|
||||
sess.handlePacket(getPacket(&wire.ExtendedHeader{
|
||||
Header: wire.Header{DestConnectionID: sess.srcConnID},
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
}, nil))
|
||||
Consistently(sess.Context().Done()).ShouldNot(BeClosed())
|
||||
// make the go routine return
|
||||
sess.closeLocal(errors.New("close"))
|
||||
|
@ -586,65 +578,61 @@ var _ = Describe("Session", func() {
|
|||
close(done)
|
||||
}()
|
||||
sessionRunner.EXPECT().retireConnectionID(gomock.Any())
|
||||
sess.handlePacket(insertPacketBuffer(&receivedPacket{
|
||||
hdr: &wire.Header{},
|
||||
data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}),
|
||||
}))
|
||||
sess.handlePacket(getPacket(&wire.ExtendedHeader{
|
||||
Header: wire.Header{DestConnectionID: sess.srcConnID},
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
}, nil))
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("handles duplicate packets", func() {
|
||||
hdr := &wire.ExtendedHeader{
|
||||
PacketNumber: 5,
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
}
|
||||
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{
|
||||
encryptionLevel: protocol.Encryption1RTT,
|
||||
hdr: hdr,
|
||||
data: []byte{0}, // one PADDING frame
|
||||
}, nil).Times(2)
|
||||
Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)}))).To(BeTrue())
|
||||
Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)}))).To(BeTrue())
|
||||
})
|
||||
|
||||
It("ignores 0-RTT packets", func() {
|
||||
Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{
|
||||
hdr: &wire.Header{
|
||||
hdr := &wire.ExtendedHeader{
|
||||
Header: wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketType0RTT,
|
||||
DestConnectionID: sess.srcConnID,
|
||||
},
|
||||
}))).To(BeFalse())
|
||||
PacketNumberLen: protocol.PacketNumberLen2,
|
||||
}
|
||||
Expect(sess.handlePacketImpl(getPacket(hdr, nil))).To(BeFalse())
|
||||
})
|
||||
|
||||
It("ignores packets with a different source connection ID", func() {
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
DestConnectionID: sess.destConnID,
|
||||
SrcConnectionID: sess.srcConnID,
|
||||
Length: 1,
|
||||
hdr1 := &wire.ExtendedHeader{
|
||||
Header: wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeHandshake,
|
||||
DestConnectionID: sess.destConnID,
|
||||
SrcConnectionID: sess.srcConnID,
|
||||
Length: 1,
|
||||
Version: sess.version,
|
||||
},
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
PacketNumber: 1,
|
||||
}
|
||||
hdr2 := &wire.ExtendedHeader{
|
||||
Header: wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeHandshake,
|
||||
DestConnectionID: sess.destConnID,
|
||||
SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef},
|
||||
Length: 1,
|
||||
Version: sess.version,
|
||||
},
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
PacketNumber: 2,
|
||||
}
|
||||
Expect(sess.srcConnID).ToNot(Equal(hdr2.SrcConnectionID))
|
||||
// Send one packet, which might change the connection ID.
|
||||
// only EXPECT one call to the unpacker
|
||||
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{
|
||||
encryptionLevel: protocol.Encryption1RTT,
|
||||
hdr: &wire.ExtendedHeader{Header: *hdr},
|
||||
hdr: hdr1,
|
||||
data: []byte{0}, // one PADDING frame
|
||||
}, nil)
|
||||
Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{
|
||||
hdr: hdr,
|
||||
data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}),
|
||||
}))).To(BeTrue())
|
||||
Expect(sess.handlePacketImpl(getPacket(hdr1, nil))).To(BeTrue())
|
||||
// The next packet has to be ignored, since the source connection ID doesn't match.
|
||||
Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{
|
||||
hdr: &wire.Header{
|
||||
IsLongHeader: true,
|
||||
DestConnectionID: sess.destConnID,
|
||||
SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef},
|
||||
Length: 1,
|
||||
},
|
||||
data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}),
|
||||
}))).To(BeFalse())
|
||||
Expect(sess.handlePacketImpl(getPacket(hdr2, nil))).To(BeFalse())
|
||||
})
|
||||
|
||||
Context("updating the remote address", func() {
|
||||
|
@ -657,14 +645,86 @@ var _ = Describe("Session", func() {
|
|||
origAddr := sess.conn.(*mockConnection).remoteAddr
|
||||
remoteIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)}
|
||||
Expect(origAddr).ToNot(Equal(remoteIP))
|
||||
Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{
|
||||
remoteAddr: remoteIP,
|
||||
hdr: &wire.Header{},
|
||||
data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}),
|
||||
}))).To(BeTrue())
|
||||
packet := getPacket(&wire.ExtendedHeader{
|
||||
Header: wire.Header{DestConnectionID: sess.srcConnID},
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
}, nil)
|
||||
packet.remoteAddr = remoteIP
|
||||
Expect(sess.handlePacketImpl(packet)).To(BeTrue())
|
||||
Expect(sess.conn.(*mockConnection).remoteAddr).To(Equal(origAddr))
|
||||
})
|
||||
})
|
||||
|
||||
Context("coalesced packets", func() {
|
||||
getPacketWithLength := func(connID protocol.ConnectionID, length protocol.ByteCount) (int /* header length */, *receivedPacket) {
|
||||
hdr := &wire.ExtendedHeader{
|
||||
Header: wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeHandshake,
|
||||
DestConnectionID: connID,
|
||||
SrcConnectionID: sess.destConnID,
|
||||
Version: protocol.VersionTLS,
|
||||
Length: length,
|
||||
},
|
||||
PacketNumberLen: protocol.PacketNumberLen3,
|
||||
}
|
||||
hdrLen := hdr.GetLength(sess.version)
|
||||
b := make([]byte, 1)
|
||||
rand.Read(b)
|
||||
packet := getPacket(hdr, bytes.Repeat(b, int(length)-3))
|
||||
return int(hdrLen), packet
|
||||
}
|
||||
|
||||
It("cuts packets to the right length", func() {
|
||||
hdrLen, packet := getPacketWithLength(sess.srcConnID, 456)
|
||||
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, data []byte) (*unpackedPacket, error) {
|
||||
Expect(data).To(HaveLen(int(hdrLen + 456 - 3)))
|
||||
return &unpackedPacket{
|
||||
encryptionLevel: protocol.EncryptionHandshake,
|
||||
data: []byte{0},
|
||||
}, nil
|
||||
})
|
||||
Expect(sess.handlePacketImpl(packet)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("handles coalesced packets", func() {
|
||||
hdrLen1, packet1 := getPacketWithLength(sess.srcConnID, 456)
|
||||
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, data []byte) (*unpackedPacket, error) {
|
||||
Expect(data).To(HaveLen(int(hdrLen1 + 456 - 3)))
|
||||
return &unpackedPacket{
|
||||
encryptionLevel: protocol.EncryptionHandshake,
|
||||
data: []byte{0},
|
||||
}, nil
|
||||
})
|
||||
hdrLen2, packet2 := getPacketWithLength(sess.srcConnID, 123)
|
||||
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, data []byte) (*unpackedPacket, error) {
|
||||
Expect(data).To(HaveLen(int(hdrLen2 + 123 - 3)))
|
||||
return &unpackedPacket{
|
||||
encryptionLevel: protocol.EncryptionHandshake,
|
||||
data: []byte{0},
|
||||
}, nil
|
||||
})
|
||||
packet1.data = append(packet1.data, packet2.data...)
|
||||
Expect(sess.handlePacketImpl(packet1)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("ignores coalesced packet parts if the destination connection IDs don't match", func() {
|
||||
wrongConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}
|
||||
Expect(sess.srcConnID).ToNot(Equal(wrongConnID))
|
||||
hdrLen1, packet1 := getPacketWithLength(sess.srcConnID, 456)
|
||||
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, data []byte) (*unpackedPacket, error) {
|
||||
Expect(data).To(HaveLen(int(hdrLen1 + 456 - 3)))
|
||||
return &unpackedPacket{
|
||||
encryptionLevel: protocol.EncryptionHandshake,
|
||||
data: []byte{0},
|
||||
}, nil
|
||||
})
|
||||
_, packet2 := getPacketWithLength(wrongConnID, 123)
|
||||
// don't EXPECT any calls to unpacker.Unpack()
|
||||
packet1.data = append(packet1.data, packet2.data...)
|
||||
Expect(sess.handlePacketImpl(packet1)).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("sending packets", func() {
|
||||
|
@ -1436,6 +1496,15 @@ var _ = Describe("Client Session", func() {
|
|||
cryptoSetup *mocks.MockCryptoSetup
|
||||
)
|
||||
|
||||
getPacket := func(hdr *wire.ExtendedHeader, data []byte) *receivedPacket {
|
||||
buf := &bytes.Buffer{}
|
||||
Expect(hdr.Write(buf, sess.version)).To(Succeed())
|
||||
return &receivedPacket{
|
||||
data: append(buf.Bytes(), data...),
|
||||
buffer: getPacketBuffer(),
|
||||
}
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
Eventually(areSessionsRunning).Should(BeFalse())
|
||||
|
||||
|
@ -1450,9 +1519,9 @@ var _ = Describe("Client Session", func() {
|
|||
nil, // tls.Config
|
||||
42, // initial packet number
|
||||
&handshake.TransportParameters{},
|
||||
protocol.VersionWhatever,
|
||||
protocol.VersionTLS,
|
||||
utils.DefaultLogger,
|
||||
protocol.VersionWhatever,
|
||||
protocol.VersionTLS,
|
||||
)
|
||||
sess = sessP.(*session)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -1479,16 +1548,16 @@ var _ = Describe("Client Session", func() {
|
|||
}()
|
||||
newConnID := protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7}
|
||||
packer.EXPECT().ChangeDestConnectionID(newConnID)
|
||||
Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{
|
||||
hdr: &wire.Header{
|
||||
Expect(sess.handlePacketImpl(getPacket(&wire.ExtendedHeader{
|
||||
Header: wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeHandshake,
|
||||
SrcConnectionID: newConnID,
|
||||
DestConnectionID: sess.srcConnID,
|
||||
Length: 1,
|
||||
},
|
||||
data: []byte{0},
|
||||
}))).To(BeTrue())
|
||||
PacketNumberLen: protocol.PacketNumberLen2,
|
||||
}, []byte{0}))).To(BeTrue())
|
||||
// make sure the go routine returns
|
||||
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil)
|
||||
sessionRunner.EXPECT().retireConnectionID(gomock.Any())
|
||||
|
@ -1498,56 +1567,52 @@ var _ = Describe("Client Session", func() {
|
|||
})
|
||||
|
||||
Context("handling Retry", func() {
|
||||
var validRetryHdr *wire.Header
|
||||
var validRetryHdr *wire.ExtendedHeader
|
||||
|
||||
BeforeEach(func() {
|
||||
validRetryHdr = &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeRetry,
|
||||
SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef},
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
OrigDestConnectionID: protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1},
|
||||
Token: []byte("foobar"),
|
||||
validRetryHdr = &wire.ExtendedHeader{
|
||||
Header: wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeRetry,
|
||||
SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef},
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
OrigDestConnectionID: protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1},
|
||||
Token: []byte("foobar"),
|
||||
Version: sess.version,
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
getPacket := func(hdr *wire.Header) *receivedPacket {
|
||||
buf := &bytes.Buffer{}
|
||||
(&wire.ExtendedHeader{Header: *hdr}).Write(buf, sess.version)
|
||||
return &receivedPacket{
|
||||
hdr: hdr,
|
||||
data: buf.Bytes(),
|
||||
buffer: getPacketBuffer(),
|
||||
}
|
||||
}
|
||||
|
||||
It("handles Retry packets", func() {
|
||||
cryptoSetup.EXPECT().ChangeConnectionID(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})
|
||||
packer.EXPECT().SetToken([]byte("foobar"))
|
||||
packer.EXPECT().ChangeDestConnectionID(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})
|
||||
Expect(sess.handlePacketImpl(getPacket(validRetryHdr))).To(BeTrue())
|
||||
Expect(sess.handlePacketImpl(getPacket(validRetryHdr, nil))).To(BeTrue())
|
||||
})
|
||||
|
||||
It("ignores Retry packets after receiving a regular packet", func() {
|
||||
sess.receivedFirstPacket = true
|
||||
Expect(sess.handlePacketImpl(getPacket(validRetryHdr))).To(BeFalse())
|
||||
Expect(sess.handlePacketImpl(getPacket(validRetryHdr, nil))).To(BeFalse())
|
||||
})
|
||||
|
||||
It("ignores Retry packets if the server didn't change the connection ID", func() {
|
||||
validRetryHdr.SrcConnectionID = sess.destConnID
|
||||
Expect(sess.handlePacketImpl(getPacket(validRetryHdr))).To(BeFalse())
|
||||
Expect(sess.handlePacketImpl(getPacket(validRetryHdr, nil))).To(BeFalse())
|
||||
})
|
||||
|
||||
It("ignores Retry packets with the wrong original destination connection ID", func() {
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeRetry,
|
||||
SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef},
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
OrigDestConnectionID: protocol.ConnectionID{1, 2, 3, 4},
|
||||
Token: []byte("foobar"),
|
||||
hdr := &wire.ExtendedHeader{
|
||||
Header: wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeRetry,
|
||||
SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef},
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
OrigDestConnectionID: protocol.ConnectionID{1, 2, 3, 4},
|
||||
Token: []byte("foobar"),
|
||||
},
|
||||
PacketNumberLen: protocol.PacketNumberLen3,
|
||||
}
|
||||
Expect(sess.handlePacketImpl(getPacket(hdr))).To(BeFalse())
|
||||
Expect(sess.handlePacketImpl(getPacket(hdr, nil))).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue