diff --git a/client.go b/client.go index d2c04bca..493da4b9 100644 --- a/client.go +++ b/client.go @@ -316,8 +316,8 @@ func (c *client) handlePacketImpl(p *receivedPacket) error { return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", p.hdr.DestConnectionID, c.srcConnID) } - if p.extHdr.Type == protocol.PacketTypeRetry { - c.handleRetryPacket(p.extHdr) + if p.hdr.Type == protocol.PacketTypeRetry { + c.handleRetryPacket(p.hdr) return nil } @@ -367,9 +367,9 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { return nil } -func (c *client) handleRetryPacket(hdr *wire.ExtendedHeader) { +func (c *client) handleRetryPacket(hdr *wire.Header) { c.logger.Debugf("<- Received Retry") - hdr.Log(c.logger) + (&wire.ExtendedHeader{Header: *hdr}).Log(c.logger) if !hdr.OrigDestConnectionID.Equal(c.destConnID) { c.logger.Debugf("Ignoring spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, c.destConnID) return diff --git a/client_test.go b/client_test.go index ae91c7c8..8d865bb3 100644 --- a/client_test.go +++ b/client_test.go @@ -512,15 +512,13 @@ var _ = Describe("Client", func() { manager.EXPECT().Add(gomock.Any(), gomock.Any()).Do(func(id protocol.ConnectionID, handler packetHandler) { go handler.handlePacket(&receivedPacket{ hdr: &wire.Header{ - Version: cl.version, - DestConnectionID: id, - }, - extHdr: &wire.ExtendedHeader{Header: wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeRetry, + Version: cl.version, Token: []byte("foobar"), OrigDestConnectionID: connID, - }}, + DestConnectionID: id, + }, }) }) manager.EXPECT().Add(gomock.Any(), gomock.Any()) @@ -575,17 +573,14 @@ var _ = Describe("Client", func() { manager.EXPECT().Add(gomock.Any(), gomock.Any()).Do(func(id protocol.ConnectionID, handler packetHandler) { go handler.handlePacket(&receivedPacket{ hdr: &wire.Header{ - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - DestConnectionID: id, - Version: cl.version, - }, - extHdr: &wire.ExtendedHeader{Header: wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeRetry, - Token: []byte("foobar"), + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + DestConnectionID: id, OrigDestConnectionID: connID, - Version: protocol.VersionTLS, - }}, + Token: []byte("foobar"), + Version: cl.version, + }, }) }).AnyTimes() manager.EXPECT().Add(gomock.Any(), gomock.Any()).AnyTimes() @@ -691,10 +686,6 @@ var _ = Describe("Client", func() { SrcConnectionID: connID, Version: cl.version, }, - extHdr: &wire.ExtendedHeader{ - PacketNumber: 1, - PacketNumberLen: protocol.PacketNumberLen2, - }, }) Expect(err).ToNot(HaveOccurred()) Expect(cl.versionNegotiated).To(BeTrue()) @@ -756,10 +747,6 @@ var _ = Describe("Client", func() { SrcConnectionID: connID, Version: cl.version, }, - extHdr: &wire.ExtendedHeader{ - PacketNumber: 1, - PacketNumberLen: protocol.PacketNumberLen1, - }, })).To(MatchError(fmt.Sprintf("received a packet with an unexpected connection ID (0x0807060504030201, expected %s)", connID))) }) }) diff --git a/internal/wire/extended_header.go b/internal/wire/extended_header.go index ced54741..21dd4b19 100644 --- a/internal/wire/extended_header.go +++ b/internal/wire/extended_header.go @@ -29,9 +29,6 @@ func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (*Exte } func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) { - if h.Type == protocol.PacketTypeRetry { - return h, nil - } pn, pnLen, err := utils.ReadVarIntPacketNumber(b) if err != nil { return nil, err diff --git a/mock_packet_handler_test.go b/mock_packet_handler_test.go index dfa884a9..4425bcb9 100644 --- a/mock_packet_handler_test.go +++ b/mock_packet_handler_test.go @@ -58,18 +58,6 @@ func (mr *MockPacketHandlerMockRecorder) GetPerspective() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPerspective", reflect.TypeOf((*MockPacketHandler)(nil).GetPerspective)) } -// GetVersion mocks base method -func (m *MockPacketHandler) GetVersion() protocol.VersionNumber { - ret := m.ctrl.Call(m, "GetVersion") - ret0, _ := ret[0].(protocol.VersionNumber) - return ret0 -} - -// GetVersion indicates an expected call of GetVersion -func (mr *MockPacketHandlerMockRecorder) GetVersion() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVersion", reflect.TypeOf((*MockPacketHandler)(nil).GetVersion)) -} - // destroy mocks base method func (m *MockPacketHandler) destroy(arg0 error) { m.ctrl.Call(m, "destroy", arg0) diff --git a/packet_handler_map.go b/packet_handler_map.go index 31b4d8b3..1a283009 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -175,11 +175,9 @@ func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error { handlerEntry, handlerFound := h.handlers[string(hdr.DestConnectionID)] server := h.server - var version protocol.VersionNumber var handlePacket func(*receivedPacket) if handlerFound { // existing session handler := handlerEntry.handler - version = handler.GetVersion() handlePacket = handler.handlePacket } else { // no session found // this might be a stateless reset @@ -201,39 +199,13 @@ func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error { return fmt.Errorf("received a packet with an unexpected connection ID %s", hdr.DestConnectionID) } handlePacket = server.handlePacket - version = hdr.Version } h.mutex.RUnlock() - var extHdr *wire.ExtendedHeader - var packetData []byte - if !hdr.IsVersionNegotiation() { - r = bytes.NewReader(data) - var err error - extHdr, err = hdr.ParseExtended(r, version) - if err != nil { - return fmt.Errorf("error parsing extended header: %s", err) - } - extHdr.Raw = data[:len(data)-r.Len()] - packetData = data[len(data)-r.Len():] - - if hdr.IsLongHeader { - if extHdr.Length < protocol.ByteCount(extHdr.PacketNumberLen) { - return fmt.Errorf("packet length (%d bytes) shorter than packet number (%d bytes)", extHdr.Length, extHdr.PacketNumberLen) - } - if protocol.ByteCount(len(packetData))+protocol.ByteCount(extHdr.PacketNumberLen) < extHdr.Length { - return fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(packetData)+int(extHdr.PacketNumberLen), extHdr.Length) - } - packetData = packetData[:int(extHdr.Length)-int(extHdr.PacketNumberLen)] - // TODO(#1312): implement parsing of compound packets - } - } - handlePacket(&receivedPacket{ remoteAddr: addr, hdr: hdr, - extHdr: extHdr, - data: packetData, + data: data, rcvTime: rcvTime, }) return nil diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index ae66c147..a955259b 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -59,15 +59,13 @@ var _ = Describe("Packet Handler Map", func() { handledPacket1 := make(chan struct{}) handledPacket2 := make(chan struct{}) packetHandler1.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - Expect(p.extHdr.DestConnectionID).To(Equal(connID1)) + Expect(p.hdr.DestConnectionID).To(Equal(connID1)) close(handledPacket1) }) - packetHandler1.EXPECT().GetVersion() packetHandler2.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - Expect(p.extHdr.DestConnectionID).To(Equal(connID2)) + Expect(p.hdr.DestConnectionID).To(Equal(connID2)) close(handledPacket2) }) - packetHandler2.EXPECT().GetVersion() handler.Add(connID1, packetHandler1) handler.Add(connID2, packetHandler2) @@ -109,7 +107,6 @@ var _ = Describe("Packet Handler Map", func() { handler.deleteRetiredSessionsAfter = time.Hour connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} packetHandler := NewMockPacketHandler(mockCtrl) - packetHandler.EXPECT().GetVersion().Return(protocol.VersionWhatever) packetHandler.EXPECT().handlePacket(gomock.Any()) handler.Add(connID, packetHandler) handler.Retire(connID) @@ -123,74 +120,6 @@ var _ = Describe("Packet Handler Map", func() { Expect(err).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708")) }) - It("errors on packets that are smaller than the length in the packet header", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - packetHandler := NewMockPacketHandler(mockCtrl) - packetHandler.EXPECT().GetVersion().Return(protocol.VersionWhatever) - handler.Add(connID, packetHandler) - hdr := &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - Length: 1000, - DestConnectionID: connID, - Version: protocol.VersionTLS, - }, - PacketNumberLen: protocol.PacketNumberLen2, - } - buf := &bytes.Buffer{} - Expect(hdr.Write(buf, protocol.VersionWhatever)).To(Succeed()) - buf.Write(bytes.Repeat([]byte{0}, 500-2 /* for packet number length */)) - - err := handler.handlePacket(nil, buf.Bytes()) - Expect(err).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)")) - }) - - It("errors when receiving a packet that has a length smaller than the packet number length", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - packetHandler := NewMockPacketHandler(mockCtrl) - packetHandler.EXPECT().GetVersion().Return(protocol.VersionWhatever) - handler.Add(connID, packetHandler) - hdr := &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - DestConnectionID: connID, - Type: protocol.PacketTypeHandshake, - Length: 3, - Version: protocol.VersionTLS, - }, - PacketNumberLen: protocol.PacketNumberLen4, - } - buf := &bytes.Buffer{} - Expect(hdr.Write(buf, protocol.VersionWhatever)).To(Succeed()) - Expect(handler.handlePacket(nil, buf.Bytes())).To(MatchError("packet length (3 bytes) shorter than packet number (4 bytes)")) - }) - - It("cuts packets to the right length", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - packetHandler := NewMockPacketHandler(mockCtrl) - packetHandler.EXPECT().GetVersion().Return(protocol.VersionWhatever) - handler.Add(connID, packetHandler) - packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - Expect(p.data).To(HaveLen(456 - int(p.extHdr.PacketNumberLen))) - }) - - hdr := &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - DestConnectionID: connID, - Type: protocol.PacketTypeHandshake, - Length: 456, - Version: protocol.VersionTLS, - }, - PacketNumberLen: protocol.PacketNumberLen1, - } - buf := &bytes.Buffer{} - Expect(hdr.Write(buf, protocol.VersionWhatever)).To(Succeed()) - buf.Write(bytes.Repeat([]byte{0}, 500)) - Expect(handler.handlePacket(nil, buf.Bytes())).To(Succeed()) - }) - It("closes the packet handlers when reading from the conn fails", func() { done := make(chan struct{}) packetHandler := NewMockPacketHandler(mockCtrl) @@ -212,9 +141,8 @@ var _ = Describe("Packet Handler Map", func() { handler.AddWithResetToken(connID, packetHandler, token) // first send a normal packet handledPacket := make(chan struct{}) - packetHandler.EXPECT().GetVersion() packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - Expect(p.extHdr.DestConnectionID).To(Equal(connID)) + Expect(p.hdr.DestConnectionID).To(Equal(connID)) close(handledPacket) }) conn.dataToRead <- getPacket(connID) @@ -257,7 +185,7 @@ var _ = Describe("Packet Handler Map", func() { p := getPacket(connID) server := NewMockUnknownPacketHandler(mockCtrl) server.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - Expect(p.extHdr.DestConnectionID).To(Equal(connID)) + Expect(p.hdr.DestConnectionID).To(Equal(connID)) }) handler.SetServer(server) Expect(handler.handlePacket(nil, p)).To(Succeed()) diff --git a/server.go b/server.go index 52beab58..bcbe041b 100644 --- a/server.go +++ b/server.go @@ -21,7 +21,6 @@ type packetHandler interface { handlePacket(*receivedPacket) io.Closer destroy(error) - GetVersion() protocol.VersionNumber GetPerspective() protocol.Perspective } @@ -306,7 +305,7 @@ func (s *server) handlePacket(p *receivedPacket) { } func (s *server) handlePacketImpl(p *receivedPacket) error { - hdr := p.extHdr + hdr := p.hdr // send a Version Negotiation Packet if the client is speaking a different protocol version if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { @@ -335,11 +334,11 @@ func (s *server) handleInitial(p *receivedPacket) { } func (s *server) handleInitialImpl(p *receivedPacket) (quicSession, protocol.ConnectionID, error) { - hdr := p.extHdr + hdr := p.hdr if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial { return nil, nil, errors.New("dropping Initial packet with too short connection ID") } - if len(hdr.Raw)+len(p.data) < protocol.MinInitialPacketSize { + if len(p.data) < protocol.MinInitialPacketSize { return nil, nil, errors.New("dropping too small Initial packet") } @@ -358,7 +357,7 @@ func (s *server) handleInitialImpl(p *receivedPacket) (quicSession, protocol.Con if !s.config.AcceptCookie(p.remoteAddr, cookie) { // Log the Initial packet now. // If no Retry is sent, the packet will be logged by the session. - p.extHdr.Log(s.logger) + (&wire.ExtendedHeader{Header: *p.hdr}).Log(s.logger) return nil, nil, s.sendRetry(p.remoteAddr, hdr) } @@ -422,7 +421,7 @@ func (s *server) createNewSession( return sess, nil } -func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.ExtendedHeader) error { +func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error { token, err := s.cookieGenerator.NewToken(remoteAddr, hdr.DestConnectionID) if err != nil { return err @@ -452,7 +451,7 @@ func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.ExtendedHeader) error } func (s *server) sendVersionNegotiationPacket(p *receivedPacket) error { - hdr := p.extHdr + hdr := p.hdr s.logger.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version) data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions) diff --git a/server_session.go b/server_session.go index f2c746a1..d1ab73a4 100644 --- a/server_session.go +++ b/server_session.go @@ -32,7 +32,7 @@ func (s *serverSession) handlePacket(p *receivedPacket) { } func (s *serverSession) handlePacketImpl(p *receivedPacket) error { - hdr := p.extHdr + hdr := p.hdr // Probably an old packet that was sent by the client before the version was negotiated. // It is safe to drop it. diff --git a/server_session_test.go b/server_session_test.go index 91aa9283..b350eb51 100644 --- a/server_session_test.go +++ b/server_session_test.go @@ -22,8 +22,8 @@ var _ = Describe("Server Session", func() { It("handles packets", func() { p := &receivedPacket{ - extHdr: &wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5}}, + hdr: &wire.Header{ + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5}, }, } qsess.EXPECT().handlePacket(p) @@ -34,12 +34,10 @@ var _ = Describe("Server Session", func() { qsess.EXPECT().GetVersion().Return(protocol.VersionNumber(100)) // don't EXPECT any calls to handlePacket() p := &receivedPacket{ - extHdr: &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - Version: protocol.VersionNumber(123), - DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - }, + hdr: &wire.Header{ + IsLongHeader: true, + Version: protocol.VersionNumber(123), + DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, }, } err := sess.handlePacketImpl(p) @@ -49,12 +47,12 @@ var _ = Describe("Server Session", func() { It("ignores packets with the wrong Long Header type", func() { qsess.EXPECT().GetVersion().Return(protocol.VersionNumber(100)) p := &receivedPacket{ - extHdr: &wire.ExtendedHeader{Header: wire.Header{ + hdr: &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeRetry, Version: protocol.VersionNumber(100), DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - }}, + }, } err := sess.handlePacketImpl(p) Expect(err).To(MatchError("Received unsupported packet type: Retry")) @@ -62,12 +60,12 @@ var _ = Describe("Server Session", func() { It("passes on Handshake packets", func() { p := &receivedPacket{ - extHdr: &wire.ExtendedHeader{Header: wire.Header{ + hdr: &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, Version: protocol.VersionNumber(100), DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - }}, + }, } qsess.EXPECT().GetVersion().Return(protocol.VersionNumber(100)) qsess.EXPECT().handlePacket(p) diff --git a/server_test.go b/server_test.go index 03a0e4f0..0643da39 100644 --- a/server_test.go +++ b/server_test.go @@ -106,43 +106,39 @@ var _ = Describe("Server", func() { It("drops Initial packets with a too short connection ID", func() { serv.handlePacket(&receivedPacket{ - extHdr: &wire.ExtendedHeader{Header: wire.Header{ + hdr: &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, Version: serv.config.Versions[0], - }}, + }, }) Expect(conn.dataWritten.Len()).To(BeZero()) }) It("drops too small Initial", func() { serv.handlePacket(&receivedPacket{ - extHdr: &wire.ExtendedHeader{Header: wire.Header{ + hdr: &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, Version: serv.config.Versions[0], - }}, + }, data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize-100), }) Consistently(conn.dataWritten.Len).Should(BeZero()) }) It("drops packets with a too short connection ID", func() { - hdr := &wire.ExtendedHeader{ - Header: wire.Header{ + serv.handlePacket(&receivedPacket{ + hdr: &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, Version: serv.config.Versions[0], }, - PacketNumberLen: protocol.PacketNumberLen1, - } - serv.handlePacket(&receivedPacket{ - extHdr: hdr, - data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), + data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), }) Consistently(conn.dataWritten.Len).Should(BeZero()) }) @@ -150,10 +146,10 @@ var _ = Describe("Server", func() { It("drops non-Initial packets", func() { serv.logger.SetLogLevel(utils.LogLevelDebug) serv.handlePacket(&receivedPacket{ - extHdr: &wire.ExtendedHeader{Header: wire.Header{ + hdr: &wire.Header{ Type: protocol.PacketTypeHandshake, Version: serv.config.Versions[0], - }}, + }, data: []byte("invalid"), }) }) @@ -174,11 +170,11 @@ var _ = Describe("Server", func() { Expect(err).ToNot(HaveOccurred()) serv.handlePacket(&receivedPacket{ remoteAddr: raddr, - extHdr: &wire.ExtendedHeader{Header: wire.Header{ + hdr: &wire.Header{ Type: protocol.PacketTypeInitial, Token: token, Version: serv.config.Versions[0], - }}, + }, data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), }) Eventually(done).Should(BeClosed()) @@ -198,11 +194,11 @@ var _ = Describe("Server", func() { } serv.handlePacket(&receivedPacket{ remoteAddr: raddr, - extHdr: &wire.ExtendedHeader{Header: wire.Header{ + hdr: &wire.Header{ Type: protocol.PacketTypeInitial, Token: []byte("foobar"), Version: serv.config.Versions[0], - }}, + }, data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), }) Eventually(done).Should(BeClosed()) @@ -212,13 +208,13 @@ var _ = Describe("Server", func() { srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5} destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6} serv.handlePacket(&receivedPacket{ - extHdr: &wire.ExtendedHeader{Header: wire.Header{ + hdr: &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, SrcConnectionID: srcConnID, DestConnectionID: destConnID, Version: 0x42, - }}, + }, }) Expect(conn.dataWritten.Len()).ToNot(BeZero()) hdr, err := wire.ParseHeader(bytes.NewReader(conn.dataWritten.Bytes()), 0) @@ -231,15 +227,15 @@ var _ = Describe("Server", func() { It("replies with a Retry packet, if a Cookie is required", func() { serv.config.AcceptCookie = func(_ net.Addr, _ *Cookie) bool { return false } - hdr := &wire.ExtendedHeader{Header: wire.Header{ + hdr := &wire.Header{ Type: protocol.PacketTypeInitial, SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, Version: protocol.VersionTLS, - }} + } serv.handleInitial(&receivedPacket{ remoteAddr: &net.UDPAddr{}, - extHdr: hdr, + hdr: hdr, data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), }) Expect(conn.dataWritten.Len()).ToNot(BeZero()) @@ -253,15 +249,15 @@ var _ = Describe("Server", func() { It("creates a session, if no Cookie is required", func() { serv.config.AcceptCookie = func(_ net.Addr, _ *Cookie) bool { return true } - hdr := &wire.ExtendedHeader{Header: wire.Header{ + hdr := &wire.Header{ Type: protocol.PacketTypeInitial, SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, Version: protocol.VersionTLS, - }} + } p := &receivedPacket{ - extHdr: hdr, - data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), + hdr: hdr, + data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), } run := make(chan struct{}) serv.newSession = func( diff --git a/session.go b/session.go index 63541d0e..c3b1892d 100644 --- a/session.go +++ b/session.go @@ -1,6 +1,7 @@ package quic import ( + "bytes" "context" "crypto/tls" "errors" @@ -53,7 +54,6 @@ type cryptoStreamHandler interface { type receivedPacket struct { remoteAddr net.Addr hdr *wire.Header - extHdr *wire.ExtendedHeader data []byte rcvTime time.Time } @@ -374,7 +374,7 @@ runLoop: } // This is a bit unclean, but works properly, since the packet always // begins with the public header and we never copy it. - putPacketBuffer(&p.extHdr.Raw) + // TODO: putPacketBuffer(&p.extHdr.Raw) case <-s.handshakeCompleteChan: s.handleHandshakeComplete() } @@ -479,14 +479,33 @@ func (s *session) handleHandshakeComplete() { } func (s *session) handlePacketImpl(p *receivedPacket) error { - hdr := p.extHdr // The server can change the source connection ID with the first Handshake packet. // After this, all packets with a different source connection have to be ignored. - if s.receivedFirstPacket && hdr.IsLongHeader && !hdr.SrcConnectionID.Equal(s.destConnID) { - s.logger.Debugf("Dropping packet with unexpected source connection ID: %s (expected %s)", p.extHdr.SrcConnectionID, s.destConnID) + if s.receivedFirstPacket && p.hdr.IsLongHeader && !p.hdr.SrcConnectionID.Equal(s.destConnID) { + s.logger.Debugf("Dropping packet with unexpected source connection ID: %s (expected %s)", p.hdr.SrcConnectionID, s.destConnID) return nil } + data := p.data + r := bytes.NewReader(data) + hdr, err := p.hdr.ParseExtended(r, s.version) + if err != nil { + return fmt.Errorf("error parsing extended header: %s", err) + } + hdr.Raw = data[:len(data)-r.Len()] + data = data[len(data)-r.Len():] + + if hdr.IsLongHeader { + if hdr.Length < protocol.ByteCount(hdr.PacketNumberLen) { + return fmt.Errorf("packet length (%d bytes) shorter than packet number (%d bytes)", hdr.Length, hdr.PacketNumberLen) + } + if protocol.ByteCount(len(data))+protocol.ByteCount(hdr.PacketNumberLen) < hdr.Length { + return fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)+int(hdr.PacketNumberLen), hdr.Length) + } + data = data[:int(hdr.Length)-int(hdr.PacketNumberLen)] + // TODO(#1312): implement parsing of compound packets + } + p.rcvTime = time.Now() // Calculate packet number hdr.PacketNumber = protocol.InferPacketNumber( @@ -496,7 +515,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { s.version, ) - packet, err := s.unpacker.Unpack(hdr.Raw, hdr, p.data) + packet, err := s.unpacker.Unpack(hdr.Raw, hdr, data) if s.logger.Debug() { if err != nil { s.logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %s", hdr.PacketNumber, len(p.data)+len(hdr.Raw), hdr.DestConnectionID) diff --git a/session_test.go b/session_test.go index d16b6a3f..70d6aec9 100644 --- a/session_test.go +++ b/session_test.go @@ -444,45 +444,48 @@ var _ = Describe("Session", func() { }) Context("receiving packets", func() { - var hdr *wire.ExtendedHeader var unpacker *MockUnpacker BeforeEach(func() { unpacker = NewMockUnpacker(mockCtrl) sess.unpacker = unpacker - hdr = &wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen4} }) - It("sets the largestRcvdPacketNumber", func() { - hdr.PacketNumber = 5 - hdr.Raw = []byte("raw header") - unpacker.EXPECT().Unpack([]byte("raw header"), hdr, []byte("foobar")).Return(&unpackedPacket{}, nil) - err := sess.handlePacketImpl(&receivedPacket{extHdr: hdr, data: []byte("foobar")}) + getData := func(extHdr *wire.ExtendedHeader) []byte { + buf := &bytes.Buffer{} + Expect(extHdr.Write(buf, sess.version)).To(Succeed()) + // need to set extHdr.Header, since the wire.Header contains the parsed length + hdr, err := wire.ParseHeader(bytes.NewReader(buf.Bytes()), 0) Expect(err).ToNot(HaveOccurred()) + extHdr.Header = *hdr + return buf.Bytes() + } + + It("sets the largestRcvdPacketNumber", func() { + hdr := &wire.ExtendedHeader{ + PacketNumber: 5, + PacketNumberLen: protocol.PacketNumberLen4, + } + hdrRaw := getData(hdr) + data := append(hdrRaw, []byte("foobar")...) + unpacker.EXPECT().Unpack(hdrRaw, gomock.Any(), []byte("foobar")).Return(&unpackedPacket{}, nil) + Expect(sess.handlePacketImpl(&receivedPacket{hdr: &hdr.Header, data: data})).To(Succeed()) Expect(sess.largestRcvdPacketNumber).To(Equal(protocol.PacketNumber(5))) }) It("informs the ReceivedPacketHandler", func() { + hdr := &wire.ExtendedHeader{ + Raw: []byte("raw header"), + PacketNumber: 5, + PacketNumberLen: protocol.PacketNumberLen4, + } unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil) rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) rph.EXPECT().ReceivedPacket(protocol.PacketNumber(5), gomock.Any(), false).Do(func(_ protocol.PacketNumber, t time.Time, _ bool) { Expect(t).To(BeTemporally("~", time.Now(), scaleDuration(25*time.Millisecond))) }) sess.receivedPacketHandler = rph - hdr.PacketNumber = 5 - Expect(sess.handlePacketImpl(&receivedPacket{extHdr: hdr})).To(Succeed()) - }) - - It("doesn't inform the ReceivedPacketHandler about Retry packets", func() { - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil) - now := time.Now().Add(time.Hour) - rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) - sess.receivedPacketHandler = rph - // don't EXPECT any call to ReceivedPacket - hdr.PacketNumber = 5 - hdr.Type = protocol.PacketTypeRetry - err := sess.handlePacketImpl(&receivedPacket{extHdr: hdr, rcvTime: now}) - Expect(err).ToNot(HaveOccurred()) + Expect(sess.handlePacketImpl(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)})).To(Succeed()) }) It("closes when handling a packet fails", func() { @@ -491,7 +494,6 @@ var _ = Describe("Session", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) cryptoSetup.EXPECT().Close() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) - hdr.PacketNumber = 5 done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -501,42 +503,99 @@ var _ = Describe("Session", func() { close(done) }() sessionRunner.EXPECT().retireConnectionID(gomock.Any()) - sess.handlePacket(&receivedPacket{extHdr: hdr}) + sess.handlePacket(&receivedPacket{hdr: &wire.Header{}, data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1})}) Eventually(done).Should(BeClosed()) }) It("handles duplicate packets", func() { unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil).Times(2) - hdr.PacketNumber = 5 - Expect(sess.handlePacketImpl(&receivedPacket{extHdr: hdr})).To(Succeed()) - Expect(sess.handlePacketImpl(&receivedPacket{extHdr: hdr})).To(Succeed()) + hdr := &wire.ExtendedHeader{ + PacketNumber: 5, + PacketNumberLen: protocol.PacketNumberLen1, + } + Expect(sess.handlePacketImpl(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)})).To(Succeed()) + Expect(sess.handlePacketImpl(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)})).To(Succeed()) }) It("ignores packets with a different source connection ID", func() { // Send one packet, which might change the connection ID. // only EXPECT one call to the unpacker unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil) - err := sess.handlePacketImpl(&receivedPacket{ - extHdr: &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - DestConnectionID: sess.destConnID, - SrcConnectionID: sess.srcConnID, - }, + Expect(sess.handlePacketImpl(&receivedPacket{ + hdr: &wire.Header{ + IsLongHeader: true, + DestConnectionID: sess.destConnID, + SrcConnectionID: sess.srcConnID, + Length: 1, }, - }) - Expect(err).ToNot(HaveOccurred()) + data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}), + })).To(Succeed()) // The next packet has to be ignored, since the source connection ID doesn't match. - err = sess.handlePacketImpl(&receivedPacket{ - extHdr: &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - DestConnectionID: sess.destConnID, - SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - }, + Expect(sess.handlePacketImpl(&receivedPacket{ + hdr: &wire.Header{ + IsLongHeader: true, + DestConnectionID: sess.destConnID, + SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + Length: 1, }, + data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}), + })).To(Succeed()) + }) + + It("errors on packets that are smaller than the length in the packet header", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + hdr := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + Length: 1000, + DestConnectionID: connID, + Version: protocol.VersionTLS, + }, + PacketNumberLen: protocol.PacketNumberLen2, + } + data := getData(hdr) + data = append(data, make([]byte, 500-2 /* for packet number length */)...) + Expect(sess.handlePacketImpl(&receivedPacket{hdr: &hdr.Header, data: data})).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)")) + }) + + It("errors when receiving a packet that has a length smaller than the packet number length", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + hdr := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + DestConnectionID: connID, + Type: protocol.PacketTypeHandshake, + Length: 3, + Version: protocol.VersionTLS, + }, + PacketNumberLen: protocol.PacketNumberLen4, + } + data := getData(hdr) + Expect(sess.handlePacketImpl(&receivedPacket{hdr: &hdr.Header, data: data})).To(MatchError("packet length (3 bytes) shorter than packet number (4 bytes)")) + }) + + It("cuts packets to the right length", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + pnLen := protocol.PacketNumberLen2 + hdr := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + DestConnectionID: connID, + Type: protocol.PacketTypeHandshake, + Length: 456, + Version: protocol.VersionTLS, + }, + PacketNumberLen: pnLen, + } + payloadLen := 456 - int(pnLen) + data := getData(hdr) + data = append(data, make([]byte, payloadLen)...) + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ []byte, _ *wire.ExtendedHeader, data []byte) (*unpackedPacket, error) { + Expect(data).To(HaveLen(payloadLen)) + return &unpackedPacket{}, nil }) - Expect(err).ToNot(HaveOccurred()) + Expect(sess.handlePacketImpl(&receivedPacket{hdr: &hdr.Header, data: data})).To(Succeed()) }) Context("updating the remote address", func() { @@ -547,7 +606,8 @@ var _ = Describe("Session", func() { Expect(origAddr).ToNot(Equal(remoteIP)) p := receivedPacket{ remoteAddr: remoteIP, - extHdr: &wire.ExtendedHeader{PacketNumber: 1337}, + hdr: &wire.Header{}, + data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}), } err := sess.handlePacketImpl(&p) Expect(err).ToNot(HaveOccurred()) @@ -1320,16 +1380,16 @@ var _ = Describe("Client Session", func() { }() newConnID := protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7} packer.EXPECT().ChangeDestConnectionID(newConnID) - err := sess.handlePacketImpl(&receivedPacket{ - extHdr: &wire.ExtendedHeader{Header: wire.Header{ + Expect(sess.handlePacketImpl(&receivedPacket{ + hdr: &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, SrcConnectionID: newConnID, DestConnectionID: sess.srcConnID, - }}, + Length: 1, + }, data: []byte{0}, - }) - Expect(err).ToNot(HaveOccurred()) + })).To(Succeed()) // make sure the go routine returns packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) sessionRunner.EXPECT().retireConnectionID(gomock.Any())