reassemble post-handshake TLS messages before passing them to crypto/tls (#4038)

This commit is contained in:
Marten Seemann 2023-08-19 07:16:57 +07:00 committed by GitHub
parent 501cc21c4b
commit 5c5db8cc59
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 49 additions and 7 deletions

View file

@ -243,7 +243,7 @@ var newConnection = func(
handshakeDestConnID: destConnID, handshakeDestConnID: destConnID,
srcConnIDLen: srcConnID.Len(), srcConnIDLen: srcConnID.Len(),
tokenGenerator: tokenGenerator, tokenGenerator: tokenGenerator,
oneRTTStream: newCryptoStream(), oneRTTStream: newCryptoStream(true),
perspective: protocol.PerspectiveServer, perspective: protocol.PerspectiveServer,
tracer: tracer, tracer: tracer,
logger: logger, logger: logger,
@ -391,7 +391,7 @@ var newClientConnection = func(
s.logger, s.logger,
) )
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize) s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize)
oneRTTStream := newCryptoStream() oneRTTStream := newCryptoStream(true)
params := &wire.TransportParameters{ params := &wire.TransportParameters{
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow), InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
@ -447,8 +447,8 @@ var newClientConnection = func(
} }
func (s *connection) preSetup() { func (s *connection) preSetup() {
s.initialStream = newCryptoStream() s.initialStream = newCryptoStream(false)
s.handshakeStream = newCryptoStream() s.handshakeStream = newCryptoStream(false)
s.sendQueue = newSendQueue(s.conn) s.sendQueue = newSendQueue(s.conn)
s.retransmissionQueue = newRetransmissionQueue() s.retransmissionQueue = newRetransmissionQueue()
s.frameParser = wire.NewFrameParser(s.config.EnableDatagrams) s.frameParser = wire.NewFrameParser(s.config.EnableDatagrams)

View file

@ -30,10 +30,17 @@ type cryptoStreamImpl struct {
writeOffset protocol.ByteCount writeOffset protocol.ByteCount
writeBuf []byte writeBuf []byte
// Reassemble TLS handshake messages before returning them from GetCryptoData.
// This is only needed because crypto/tls doesn't correctly handle post-handshake messages.
onlyCompleteMsg bool
} }
func newCryptoStream() cryptoStream { func newCryptoStream(onlyCompleteMsg bool) cryptoStream {
return &cryptoStreamImpl{queue: newFrameSorter()} return &cryptoStreamImpl{
queue: newFrameSorter(),
onlyCompleteMsg: onlyCompleteMsg,
}
} }
func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error { func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error {
@ -71,6 +78,20 @@ func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error {
// GetCryptoData retrieves data that was received in CRYPTO frames // GetCryptoData retrieves data that was received in CRYPTO frames
func (s *cryptoStreamImpl) GetCryptoData() []byte { func (s *cryptoStreamImpl) GetCryptoData() []byte {
if s.onlyCompleteMsg {
if len(s.msgBuf) < 4 {
return nil
}
msgLen := 4 + int(s.msgBuf[1])<<16 + int(s.msgBuf[2])<<8 + int(s.msgBuf[3])
if len(s.msgBuf) < msgLen {
return nil
}
msg := make([]byte, msgLen)
copy(msg, s.msgBuf[:msgLen])
s.msgBuf = s.msgBuf[msgLen:]
return msg
}
b := s.msgBuf b := s.msgBuf
s.msgBuf = nil s.msgBuf = nil
return b return b

View file

@ -1,6 +1,7 @@
package quic package quic
import ( import (
"crypto/rand"
"fmt" "fmt"
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/protocol"
@ -15,7 +16,7 @@ var _ = Describe("Crypto Stream", func() {
var str cryptoStream var str cryptoStream
BeforeEach(func() { BeforeEach(func() {
str = newCryptoStream() str = newCryptoStream(false)
}) })
Context("handling incoming data", func() { Context("handling incoming data", func() {
@ -137,4 +138,23 @@ var _ = Describe("Crypto Stream", func() {
Expect(f.Data).To(Equal([]byte("bar"))) Expect(f.Data).To(Equal([]byte("bar")))
}) })
}) })
It("reassembles data", func() {
str = newCryptoStream(true)
data := make([]byte, 1337)
l := len(data) - 4
data[1] = uint8(l >> 16)
data[2] = uint8(l >> 8)
data[3] = uint8(l)
rand.Read(data[4:])
for i, b := range data {
Expect(str.GetCryptoData()).To(BeEmpty())
Expect(str.HandleCryptoFrame(&wire.CryptoFrame{
Offset: protocol.ByteCount(i),
Data: []byte{b},
})).To(Succeed())
}
Expect(str.GetCryptoData()).To(Equal(data))
})
}) })

View file

@ -408,6 +408,7 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.
} }
client.HandleMessage(ticket, protocol.Encryption1RTT) client.HandleMessage(ticket, protocol.Encryption1RTT)
} }
if sendPostHandshakeMessageToClient { if sendPostHandshakeMessageToClient {
fmt.Println("sending post handshake message to the client at", messageToReplaceEncLevel) fmt.Println("sending post handshake message to the client at", messageToReplaceEncLevel)
client.HandleMessage(data, messageToReplaceEncLevel) client.HandleMessage(data, messageToReplaceEncLevel)