diff --git a/client_test.go b/client_test.go index 8c4cfcbd..0bc1e2ef 100644 --- a/client_test.go +++ b/client_test.go @@ -601,6 +601,8 @@ var _ = Describe("Client", func() { Eventually(cl.versionNegotiated.Get).Should(BeTrue()) }) + // Illustrates that adversary that injects a version negotiation packet + // with no supported versions can break a connection. It("errors if no matching version is found", func() { sess := NewMockQuicSession(mockCtrl) done := make(chan struct{}) @@ -663,4 +665,34 @@ var _ = Describe("Client", func() { Expect(cl.version).ToNot(BeZero()) Expect(cl.GetVersion()).To(Equal(cl.version)) }) + + Context("handling potentially injected packets", func() { + // NOTE: We hope these tests as written will fail once mitigations for injection adversaries are put in place. + + // Illustrates that adversary who injects any packet quickly can + // cause a real version negotiation packet to be ignored. + It("version negotiation packets ignored if any other packet is received", func() { + // Copy of existing test "recognizes that a non Version Negotiation packet means that the server accepted the suggested version" + sess := NewMockQuicSession(mockCtrl) + sess.EXPECT().handlePacket(gomock.Any()) + cl.session = sess + cl.config = &Config{} + buf := &bytes.Buffer{} + Expect((&wire.ExtendedHeader{ + Header: wire.Header{ + DestConnectionID: connID, + SrcConnectionID: connID, + Version: cl.version, + }, + PacketNumberLen: protocol.PacketNumberLen3, + }).Write(buf, protocol.VersionTLS)).To(Succeed()) + cl.handlePacket(&receivedPacket{data: buf.Bytes()}) + + // Version negotiation is now ignored + cl.config = &Config{} + ver := cl.version + cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{1234})) + Expect(cl.version).To(Equal(ver)) + }) + }) }) diff --git a/integrationtests/self/mitm_test.go b/integrationtests/self/mitm_test.go index c31b8e9c..0290fd3e 100644 --- a/integrationtests/self/mitm_test.go +++ b/integrationtests/self/mitm_test.go @@ -15,6 +15,7 @@ import ( quicproxy "github.com/lucas-clemente/quic-go/integrationtests/tools/proxy" "github.com/lucas-clemente/quic-go/integrationtests/tools/testserver" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/testutils" "github.com/lucas-clemente/quic-go/internal/wire" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -72,53 +73,100 @@ var _ = Describe("MITM test", func() { Expect(err).ToNot(HaveOccurred()) }) - AfterEach(func() { - Eventually(serverSess.Context().Done()).Should(BeClosed()) - // Test shutdown is tricky due to the proxy. Just wait for a bit. - time.Sleep(50 * time.Millisecond) - Expect(clientConn.Close()).To(Succeed()) - Expect(serverConn.Close()).To(Succeed()) - Expect(proxy.Close()).To(Succeed()) - }) + Context("unsuccessful attacks", func() { + AfterEach(func() { + Eventually(serverSess.Context().Done()).Should(BeClosed()) + // Test shutdown is tricky due to the proxy. Just wait for a bit. + time.Sleep(50 * time.Millisecond) + Expect(clientConn.Close()).To(Succeed()) + Expect(serverConn.Close()).To(Succeed()) + Expect(proxy.Close()).To(Succeed()) + }) - Context("injecting invalid packets", func() { - const rtt = 20 * time.Millisecond + Context("injecting invalid packets", func() { + const rtt = 20 * time.Millisecond - sendRandomPacketsOfSameType := func(conn net.PacketConn, remoteAddr net.Addr, raw []byte) { - defer GinkgoRecover() - hdr, _, _, err := wire.ParsePacket(raw, connIDLen) - Expect(err).ToNot(HaveOccurred()) - replyHdr := &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: hdr.IsLongHeader, - DestConnectionID: hdr.DestConnectionID, - SrcConnectionID: hdr.SrcConnectionID, - Type: hdr.Type, - Version: hdr.Version, - }, - PacketNumber: protocol.PacketNumber(mrand.Int31n(math.MaxInt32 / 4)), - PacketNumberLen: protocol.PacketNumberLen(mrand.Int31n(4) + 1), - } - - const numPackets = 10 - ticker := time.NewTicker(rtt / numPackets) - for i := 0; i < numPackets; i++ { - payloadLen := mrand.Int31n(100) - replyHdr.Length = protocol.ByteCount(mrand.Int31n(payloadLen + 1)) - buf := &bytes.Buffer{} - Expect(replyHdr.Write(buf, v)).To(Succeed()) - b := make([]byte, payloadLen) - mrand.Read(b) - buf.Write(b) - if _, err := conn.WriteTo(buf.Bytes(), remoteAddr); err != nil { - return + sendRandomPacketsOfSameType := func(conn net.PacketConn, remoteAddr net.Addr, raw []byte) { + defer GinkgoRecover() + hdr, _, _, err := wire.ParsePacket(raw, connIDLen) + Expect(err).ToNot(HaveOccurred()) + replyHdr := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: hdr.IsLongHeader, + DestConnectionID: hdr.DestConnectionID, + SrcConnectionID: hdr.SrcConnectionID, + Type: hdr.Type, + Version: hdr.Version, + }, + PacketNumber: protocol.PacketNumber(mrand.Int31n(math.MaxInt32 / 4)), + PacketNumberLen: protocol.PacketNumberLen(mrand.Int31n(4) + 1), } - <-ticker.C - } - } - runTest := func(delayCb quicproxy.DelayCallback) { - startServerAndProxy(delayCb, nil) + const numPackets = 10 + ticker := time.NewTicker(rtt / numPackets) + for i := 0; i < numPackets; i++ { + payloadLen := mrand.Int31n(100) + replyHdr.Length = protocol.ByteCount(mrand.Int31n(payloadLen + 1)) + buf := &bytes.Buffer{} + Expect(replyHdr.Write(buf, v)).To(Succeed()) + b := make([]byte, payloadLen) + mrand.Read(b) + buf.Write(b) + if _, err := conn.WriteTo(buf.Bytes(), remoteAddr); err != nil { + return + } + <-ticker.C + } + } + + runTest := func(delayCb quicproxy.DelayCallback) { + startServerAndProxy(delayCb, nil) + raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort())) + Expect(err).ToNot(HaveOccurred()) + sess, err := quic.Dial( + clientConn, + raddr, + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + getTLSClientConfig(), + &quic.Config{ + Versions: []protocol.VersionNumber{version}, + ConnectionIDLength: connIDLen, + }, + ) + Expect(err).ToNot(HaveOccurred()) + str, err := sess.AcceptUniStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + data, err := ioutil.ReadAll(str) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal(testserver.PRData)) + Expect(sess.Close()).To(Succeed()) + } + + It("downloads a message when the packets are injected towards the server", func() { + delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { + if dir == quicproxy.DirectionIncoming { + defer GinkgoRecover() + go sendRandomPacketsOfSameType(clientConn, serverConn.LocalAddr(), raw) + } + return rtt / 2 + } + runTest(delayCb) + }) + + It("downloads a message when the packets are injected towards the client", func() { + delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { + if dir == quicproxy.DirectionOutgoing { + defer GinkgoRecover() + go sendRandomPacketsOfSameType(serverConn, clientConn.LocalAddr(), raw) + } + return rtt / 2 + } + runTest(delayCb) + }) + }) + + runTest := func(dropCb quicproxy.DropCallback) { + startServerAndProxy(nil, dropCb) raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort())) Expect(err).ToNot(HaveOccurred()) sess, err := quic.Dial( @@ -140,130 +188,259 @@ var _ = Describe("MITM test", func() { Expect(sess.Close()).To(Succeed()) } - It("downloads a message when the packets are injected towards the server", func() { - delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { - if dir == quicproxy.DirectionIncoming { + Context("duplicating packets", func() { + It("downloads a message when packets are duplicated towards the server", func() { + dropCb := func(dir quicproxy.Direction, raw []byte) bool { defer GinkgoRecover() - go sendRandomPacketsOfSameType(clientConn, serverConn.LocalAddr(), raw) + if dir == quicproxy.DirectionIncoming { + _, err := clientConn.WriteTo(raw, serverConn.LocalAddr()) + Expect(err).ToNot(HaveOccurred()) + } + return false } - return rtt / 2 - } - runTest(delayCb) + runTest(dropCb) + }) + + It("downloads a message when packets are duplicated towards the client", func() { + dropCb := func(dir quicproxy.Direction, raw []byte) bool { + defer GinkgoRecover() + if dir == quicproxy.DirectionOutgoing { + _, err := serverConn.WriteTo(raw, clientConn.LocalAddr()) + Expect(err).ToNot(HaveOccurred()) + } + return false + } + runTest(dropCb) + }) }) - It("downloads a message when the packets are injected towards the client", func() { - delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { - if dir == quicproxy.DirectionOutgoing { + Context("corrupting packets", func() { + const interval = 10 // corrupt every 10th packet (stochastically) + const idleTimeout = time.Second + + var numCorrupted int32 + + BeforeEach(func() { + numCorrupted = 0 + serverConfig.IdleTimeout = idleTimeout + }) + + AfterEach(func() { + num := atomic.LoadInt32(&numCorrupted) + fmt.Fprintf(GinkgoWriter, "Corrupted %d packets.", num) + Expect(num).To(BeNumerically(">=", 1)) + // If the packet containing the CONNECTION_CLOSE is corrupted, + // we have to wait for the session to time out. + Eventually(serverSess.Context().Done(), 3*idleTimeout).Should(BeClosed()) + }) + + It("downloads a message when packet are corrupted towards the server", func() { + dropCb := func(dir quicproxy.Direction, raw []byte) bool { defer GinkgoRecover() - go sendRandomPacketsOfSameType(serverConn, clientConn.LocalAddr(), raw) + if dir == quicproxy.DirectionIncoming && mrand.Intn(interval) == 0 { + pos := mrand.Intn(len(raw)) + raw[pos] = byte(mrand.Intn(256)) + _, err := clientConn.WriteTo(raw, serverConn.LocalAddr()) + Expect(err).ToNot(HaveOccurred()) + atomic.AddInt32(&numCorrupted, 1) + return true + } + return false } - return rtt / 2 - } - runTest(delayCb) + runTest(dropCb) + }) + + It("downloads a message when packet are corrupted towards the client", func() { + dropCb := func(dir quicproxy.Direction, raw []byte) bool { + defer GinkgoRecover() + isRetry := raw[0]&0xc0 == 0xc0 // don't corrupt Retry packets + if dir == quicproxy.DirectionOutgoing && mrand.Intn(interval) == 0 && !isRetry { + pos := mrand.Intn(len(raw)) + raw[pos] = byte(mrand.Intn(256)) + _, err := serverConn.WriteTo(raw, clientConn.LocalAddr()) + Expect(err).ToNot(HaveOccurred()) + atomic.AddInt32(&numCorrupted, 1) + return true + } + return false + } + runTest(dropCb) + }) }) }) - runTest := func(dropCb quicproxy.DropCallback) { - startServerAndProxy(nil, dropCb) - raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort())) - Expect(err).ToNot(HaveOccurred()) - sess, err := quic.Dial( - clientConn, - raddr, - fmt.Sprintf("localhost:%d", proxy.LocalPort()), - getTLSClientConfig(), - &quic.Config{ - Versions: []protocol.VersionNumber{version}, - ConnectionIDLength: connIDLen, - }, - ) - Expect(err).ToNot(HaveOccurred()) - str, err := sess.AcceptUniStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - data, err := ioutil.ReadAll(str) - Expect(err).ToNot(HaveOccurred()) - Expect(data).To(Equal(testserver.PRData)) - Expect(sess.Close()).To(Succeed()) - } + Context("successful injection attacks", func() { + // These tests demonstrate that the QUIC protocol is vulnerable to injection attacks before the handshake + // finishes. In particular, an adversary who can intercept packets coming from one endpoint and send a reply + // that arrives before the real reply can tear down the connection in multiple ways. - Context("duplicating packets", func() { - It("downloads a message when packets are duplicated towards the server", func() { - dropCb := func(dir quicproxy.Direction, raw []byte) bool { - defer GinkgoRecover() - if dir == quicproxy.DirectionIncoming { - _, err := clientConn.WriteTo(raw, serverConn.LocalAddr()) - Expect(err).ToNot(HaveOccurred()) - } - return false - } - runTest(dropCb) - }) - - It("downloads a message when packets are duplicated towards the client", func() { - dropCb := func(dir quicproxy.Direction, raw []byte) bool { - defer GinkgoRecover() - if dir == quicproxy.DirectionOutgoing { - _, err := serverConn.WriteTo(raw, clientConn.LocalAddr()) - Expect(err).ToNot(HaveOccurred()) - } - return false - } - runTest(dropCb) - }) - }) - - Context("corrupting packets", func() { - const interval = 10 // corrupt every 10th packet (stochastically) - const idleTimeout = time.Second - - var numCorrupted int32 - - BeforeEach(func() { - numCorrupted = 0 - serverConfig.IdleTimeout = idleTimeout - }) + const rtt = 20 * time.Millisecond + // AfterEach closes the proxy, but each function is responsible + // for closing client and server connections AfterEach(func() { - num := atomic.LoadInt32(&numCorrupted) - fmt.Fprintf(GinkgoWriter, "Corrupted %d packets.", num) - Expect(num).To(BeNumerically(">=", 1)) - // If the packet containing the CONNECTION_CLOSE is corrupted, - // we have to wait for the session to time out. - Eventually(serverSess.Context().Done(), 3*idleTimeout).Should(BeClosed()) + // Test shutdown is tricky due to the proxy. Just wait for a bit. + time.Sleep(50 * time.Millisecond) + Expect(proxy.Close()).To(Succeed()) }) - It("downloads a message when packet are corrupted towards the server", func() { - dropCb := func(dir quicproxy.Direction, raw []byte) bool { - defer GinkgoRecover() - if dir == quicproxy.DirectionIncoming && mrand.Intn(interval) == 0 { - pos := mrand.Intn(len(raw)) - raw[pos] = byte(mrand.Intn(256)) - _, err := clientConn.WriteTo(raw, serverConn.LocalAddr()) + // sendForgedVersionNegotiationPacket sends a fake VN packet with no supported versions + // from serverConn to client's remoteAddr + // expects hdr from an Initial packet intercepted from client + sendForgedVersionNegotationPacket := func(conn net.PacketConn, remoteAddr net.Addr, hdr *wire.Header) { + defer GinkgoRecover() + + // Create fake version negotiation packet with no supported versions + versions := []protocol.VersionNumber{} + packet, _ := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, versions) + + // Send the packet + _, err := conn.WriteTo(packet, remoteAddr) + Expect(err).ToNot(HaveOccurred()) + } + + // sendForgedRetryPacket sends a fake Retry packet with a modified srcConnID + // from serverConn to client's remoteAddr + // expects hdr from an Initial packet intercepted from client + sendForgedRetryPacket := func(conn net.PacketConn, remoteAddr net.Addr, hdr *wire.Header) { + defer GinkgoRecover() + + var x byte = 0x12 + fakeSrcConnID := protocol.ConnectionID{x, x, x, x, x, x, x, x} + retryPacket := testutils.ComposeRetryPacket(fakeSrcConnID, hdr.SrcConnectionID, hdr.DestConnectionID, []byte("token"), hdr.Version) + + _, err := conn.WriteTo(retryPacket, remoteAddr) + Expect(err).ToNot(HaveOccurred()) + } + + // Send a forged Initial packet with no frames to client + // expects hdr from an Initial packet intercepted from client + sendForgedInitialPacket := func(conn net.PacketConn, remoteAddr net.Addr, hdr *wire.Header) { + defer GinkgoRecover() + + initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.Version, hdr.DestConnectionID, nil) + _, err := conn.WriteTo(initialPacket, remoteAddr) + Expect(err).ToNot(HaveOccurred()) + } + + // Send a forged Initial packet with ACK for random packet to client + // expects hdr from an Initial packet intercepted from client + sendForgedInitialPacketWithAck := func(conn net.PacketConn, remoteAddr net.Addr, hdr *wire.Header) { + defer GinkgoRecover() + + // Fake Initial with ACK for packet 2 (unsent) + ackFrame := testutils.ComposeAckFrame(2, 2) + initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.Version, hdr.DestConnectionID, []wire.Frame{ackFrame}) + _, err := conn.WriteTo(initialPacket, remoteAddr) + Expect(err).ToNot(HaveOccurred()) + } + + // runTestFail succeeds if an error occurs in dialing + // expects a proxy delay function that runs every time a packet is received + runTestFail := func(delayCb quicproxy.DelayCallback) { + startServerAndProxy(delayCb, nil) + raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort())) + Expect(err).ToNot(HaveOccurred()) + _, err = quic.Dial( + clientConn, + raddr, + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + getTLSClientConfig(), + &quic.Config{ + Versions: []protocol.VersionNumber{version}, + ConnectionIDLength: connIDLen, + }, + ) + Expect(err).To(HaveOccurred()) + } + + // fails immediately because client connection closes when it can't find compatible version + It("fails when a forged version negotiation packet is sent to client", func() { + delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { + fmt.Println() + if dir == quicproxy.DirectionIncoming { + defer GinkgoRecover() + + hdr, _, _, err := wire.ParsePacket(raw, connIDLen) Expect(err).ToNot(HaveOccurred()) - atomic.AddInt32(&numCorrupted, 1) - return true + + if hdr.Type != protocol.PacketTypeInitial { + return 0 + } + + go sendForgedVersionNegotationPacket(serverConn, clientConn.LocalAddr(), hdr) } - return false + return rtt / 2 } - runTest(dropCb) + runTestFail(delayCb) }) - It("downloads a message when packet are corrupted towards the client", func() { - dropCb := func(dir quicproxy.Direction, raw []byte) bool { - defer GinkgoRecover() - isRetry := raw[0]&0xc0 == 0xc0 // don't corrupt Retry packets - if dir == quicproxy.DirectionOutgoing && mrand.Intn(interval) == 0 && !isRetry { - pos := mrand.Intn(len(raw)) - raw[pos] = byte(mrand.Intn(256)) - _, err := serverConn.WriteTo(raw, clientConn.LocalAddr()) + // times out, because client doesn't accept subsequent real retry packets from server + // as it has already accepted a retry. + // TODO: determine behavior when server does not send Retry packets + initialPacketIntercepted := false + It("fails when a forged retry packet with modified srcConnID is sent to client", func() { + delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { + if dir == quicproxy.DirectionIncoming && !initialPacketIntercepted { + defer GinkgoRecover() + + hdr, _, _, err := wire.ParsePacket(raw, connIDLen) Expect(err).ToNot(HaveOccurred()) - atomic.AddInt32(&numCorrupted, 1) - return true + + if hdr.Type != protocol.PacketTypeInitial { + return 0 + } + + initialPacketIntercepted = true + go sendForgedRetryPacket(serverConn, clientConn.LocalAddr(), hdr) } - return false + return rtt / 2 } - runTest(dropCb) + runTestFail(delayCb) }) + + // times out, because client doesn't accept real retry packets from server because + // it has already accepted an initial. + // TODO: determine behavior when server does not send Retry packets + It("fails when a forged initial packet is sent to client", func() { + delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { + if dir == quicproxy.DirectionIncoming { + defer GinkgoRecover() + + hdr, _, _, err := wire.ParsePacket(raw, connIDLen) + Expect(err).ToNot(HaveOccurred()) + + if hdr.Type != protocol.PacketTypeInitial { + return 0 + } + + go sendForgedInitialPacket(serverConn, clientConn.LocalAddr(), hdr) + } + return rtt + } + runTestFail(delayCb) + }) + + // client connection closes immediately on receiving ack for unsent packet + It("fails when a forged initial packet with ack for unsent packet is sent to client", func() { + delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { + if dir == quicproxy.DirectionIncoming { + defer GinkgoRecover() + + hdr, _, _, err := wire.ParsePacket(raw, connIDLen) + Expect(err).ToNot(HaveOccurred()) + + if hdr.Type != protocol.PacketTypeInitial { + return 0 + } + + go sendForgedInitialPacketWithAck(serverConn, clientConn.LocalAddr(), hdr) + } + return rtt + } + runTestFail(delayCb) + }) + }) }) } diff --git a/internal/testutils/testutils.go b/internal/testutils/testutils.go new file mode 100644 index 00000000..c6b711fc --- /dev/null +++ b/internal/testutils/testutils.go @@ -0,0 +1,126 @@ +package testutils + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +// Utilities for simulating packet injection and man-in-the-middle (MITM) attacker tests. +// Do not use for non-testing purposes. + +// CryptoFrameType uses same types as messageType in crypto_setup.go +type CryptoFrameType uint8 + +// writePacket returns a new raw packet with the specified header and payload +func writePacket(hdr *wire.ExtendedHeader, data []byte) []byte { + buf := &bytes.Buffer{} + hdr.Write(buf, protocol.VersionTLS) + return append(buf.Bytes(), data...) +} + +// packRawPayload returns a new raw payload containing given frames +func packRawPayload(version protocol.VersionNumber, frames []wire.Frame) []byte { + buf := new(bytes.Buffer) + for _, cf := range frames { + cf.Write(buf, version) + } + return buf.Bytes() +} + +// ComposeCryptoFrame returns a new empty crypto frame of the specified +// type padded to size bytes with zeroes +func ComposeCryptoFrame(cft CryptoFrameType, size int) *wire.CryptoFrame { + data := make([]byte, size) + data[0] = byte(cft) + return &wire.CryptoFrame{ + Offset: 0, + Data: data, + } +} + +// ComposeConnCloseFrame returns a new Connection Close frame with a generic error +func ComposeConnCloseFrame() *wire.ConnectionCloseFrame { + return &wire.ConnectionCloseFrame{ + IsApplicationError: true, + ErrorCode: 0, + ReasonPhrase: "mitm attacker", + } +} + +// ComposeAckFrame returns a new Ack Frame that acknowledges all packets between smallest and largest +func ComposeAckFrame(smallest protocol.PacketNumber, largest protocol.PacketNumber) *wire.AckFrame { + ackRange := wire.AckRange{ + Smallest: smallest, + Largest: largest, + } + return &wire.AckFrame{ + AckRanges: []wire.AckRange{ackRange}, + DelayTime: 0, + } +} + +// ComposeInitialPacket returns an Initial packet encrypted under key +// (the original destination connection ID) containing specified frames +func ComposeInitialPacket(srcConnID protocol.ConnectionID, destConnID protocol.ConnectionID, version protocol.VersionNumber, key protocol.ConnectionID, frames []wire.Frame) []byte { + sealer, _, _ := handshake.NewInitialAEAD(key, protocol.PerspectiveServer) + + // compose payload + var payload []byte + if len(frames) == 0 { + payload = make([]byte, protocol.MinInitialPacketSize) + } else { + payload = packRawPayload(version, frames) + } + + // compose Initial header + payloadSize := len(payload) + pnLength := protocol.PacketNumberLen4 + length := payloadSize + int(pnLength) + sealer.Overhead() + hdr := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + SrcConnectionID: srcConnID, + DestConnectionID: destConnID, + Length: protocol.ByteCount(length), + Version: version, + }, + PacketNumberLen: pnLength, + PacketNumber: 0x0, + } + + raw := writePacket(hdr, payload) + + // encrypt payload and header + payloadOffset := len(raw) - payloadSize + var encrypted []byte + encrypted = sealer.Seal(encrypted, payload, hdr.PacketNumber, raw[:payloadOffset]) + hdrBytes := raw[0:payloadOffset] + encrypted = append(hdrBytes, encrypted...) + pnOffset := payloadOffset - int(pnLength) // packet number offset + sealer.EncryptHeader( + encrypted[payloadOffset:payloadOffset+16], // first 16 bytes of payload (sample) + &encrypted[0], // first byte of header + encrypted[pnOffset:payloadOffset], // packet number bytes + ) + return encrypted +} + +// ComposeRetryPacket returns a new raw Retry Packet +func ComposeRetryPacket(srcConnID protocol.ConnectionID, destConnID protocol.ConnectionID, origDestConnID protocol.ConnectionID, token []byte, version protocol.VersionNumber) []byte { + hdr := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeRetry, + SrcConnectionID: srcConnID, + DestConnectionID: destConnID, + OrigDestConnectionID: origDestConnID, + Token: token, + Version: version, + }, + } + return writePacket(hdr, nil) +} diff --git a/session_test.go b/session_test.go index b070520c..5f78951d 100644 --- a/session_test.go +++ b/session_test.go @@ -20,6 +20,7 @@ import ( mockackhandler "github.com/lucas-clemente/quic-go/internal/mocks/ackhandler" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/qerr" + "github.com/lucas-clemente/quic-go/internal/testutils" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" ) @@ -1681,4 +1682,102 @@ var _ = Describe("Client Session", func() { Expect(err).To(MatchError("expected original_connection_id to equal 0xdeadbeef, is 0xdecafbad")) }) }) + + Context("handling potentially injected packets", func() { + var unpacker *MockUnpacker + + getPacket := func(extHdr *wire.ExtendedHeader, data []byte) *receivedPacket { + buf := &bytes.Buffer{} + Expect(extHdr.Write(buf, sess.version)).To(Succeed()) + return &receivedPacket{ + data: append(buf.Bytes(), data...), + buffer: getPacketBuffer(), + } + } + + // Convert an already packed raw packet into a receivedPacket + wrapPacket := func(packet []byte) *receivedPacket { + return &receivedPacket{ + data: packet, + buffer: getPacketBuffer(), + } + } + + // Illustrates that attacker may inject an Initial packet with a different + // source connection ID, causing endpoint to ignore a subsequent real Initial packets. + It("ignores Initial packets with a different source connection ID", func() { + // Modified from test "ignores packets with a different source connection ID" + unpacker = NewMockUnpacker(mockCtrl) + sess.unpacker = unpacker + + hdr1 := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + DestConnectionID: sess.destConnID, + SrcConnectionID: sess.srcConnID, + Length: 1, + Version: sess.version, + }, + PacketNumberLen: protocol.PacketNumberLen1, + PacketNumber: 1, + } + hdr2 := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + DestConnectionID: sess.destConnID, + SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + Length: 1, + Version: sess.version, + }, + PacketNumberLen: protocol.PacketNumberLen1, + PacketNumber: 2, + } + Expect(sess.srcConnID).ToNot(Equal(hdr2.SrcConnectionID)) + // Send one packet, which might change the connection ID. + packer.EXPECT().ChangeDestConnectionID(sess.srcConnID).MaxTimes(1) + // only EXPECT one call to the unpacker + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{ + encryptionLevel: protocol.EncryptionInitial, + hdr: hdr1, + data: []byte{0}, // one PADDING frame + }, nil) + Expect(sess.handlePacketImpl(getPacket(hdr1, nil))).To(BeTrue()) + // The next packet has to be ignored, since the source connection ID doesn't match. + Expect(sess.handlePacketImpl(getPacket(hdr2, nil))).To(BeFalse()) + }) + + // Illustrates that an injected Initial with an ACK frame for an unsent packet causes + // the connection to immediately break down + It("fails on Initial-level ACK for unsent packet", func() { + sessionRunner.EXPECT().Retire(gomock.Any()) + ackFrame := testutils.ComposeAckFrame(0, 0) + initialPacket := testutils.ComposeInitialPacket(sess.destConnID, sess.srcConnID, sess.version, sess.destConnID, []wire.Frame{ackFrame}) + Expect(sess.handlePacketImpl(wrapPacket(initialPacket))).To(BeFalse()) + }) + + // Illustrates that an injected Initial with a CONNECTION_CLOSE frame causes + // the connection to immediately break down + It("fails on Initial-level CONNECTION_CLOSE frame", func() { + sessionRunner.EXPECT().Remove(gomock.Any()) + connCloseFrame := testutils.ComposeConnCloseFrame() + initialPacket := testutils.ComposeInitialPacket(sess.destConnID, sess.srcConnID, sess.version, sess.destConnID, []wire.Frame{connCloseFrame}) + Expect(sess.handlePacketImpl(wrapPacket(initialPacket))).To(BeTrue()) + }) + + // Illustrates that attacker who injects a Retry packet and changes the connection ID + // can cause subsequent real Initial packets to be ignored + It("ignores Initial packets which use original source id, after accepting a Retry", func() { + newSrcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + cryptoSetup.EXPECT().ChangeConnectionID(newSrcConnID) + packer.EXPECT().SetToken([]byte("foobar")) + packer.EXPECT().ChangeDestConnectionID(newSrcConnID) + + sess.handlePacketImpl(wrapPacket(testutils.ComposeRetryPacket(newSrcConnID, sess.destConnID, sess.destConnID, []byte("foobar"), sess.version))) + initialPacket := testutils.ComposeInitialPacket(sess.destConnID, sess.srcConnID, sess.version, sess.destConnID, nil) + Expect(sess.handlePacketImpl(wrapPacket(initialPacket))).To(BeFalse()) + }) + + }) })