use the new crypto/tls QUIC Transport (#3860)

This commit is contained in:
Marten Seemann 2023-07-01 11:15:00 -07:00 committed by GitHub
parent 4998733ae1
commit 3d89e545d3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
55 changed files with 2197 additions and 1509 deletions

View file

@ -1,11 +1,5 @@
version: 2.1
executors:
test-go119:
docker:
- image: "cimg/go:1.19"
environment:
runrace: true
TIMESCALE_FACTOR: 3
test-go120:
docker:
- image: "cimg/go:1.20"
@ -15,7 +9,7 @@ executors:
jobs:
"test": &test
executor: test-go119
executor: test-go120
steps:
- checkout
- run:
@ -39,14 +33,10 @@ jobs:
- run:
name: "Run version negotiation tests with qlog"
command: go run github.com/onsi/ginkgo/v2/ginkgo -v -randomize-all -trace integrationtests/versionnegotiation -- -qlog
go119:
<<: *test
go120:
<<: *test
executor: test-go120
workflows:
workflow:
jobs:
- go119
- go120

View file

@ -4,7 +4,7 @@ jobs:
strategy:
fail-fast: false
matrix:
go: [ "1.19.x", "1.20.x" ]
go: [ "1.20.x", "1.21.0-rc.2" ]
runs-on: ${{ fromJSON(vars['CROSS_COMPILE_RUNNER_UBUNTU'] || '"ubuntu-latest"') }}
name: "Cross Compilation (Go ${{matrix.go}})"
steps:

View file

@ -5,7 +5,7 @@ jobs:
strategy:
fail-fast: false
matrix:
go: [ "1.19.x", "1.20.x" ]
go: [ "1.20.x", "1.21.0-rc.2" ]
runs-on: ${{ fromJSON(vars['INTEGRATION_RUNNER_UBUNTU'] || '"ubuntu-latest"') }}
env:
DEBUG: false # set this to true to export qlogs and save them as artifacts

View file

@ -7,7 +7,7 @@ jobs:
fail-fast: false
matrix:
os: [ "ubuntu", "windows", "macos" ]
go: [ "1.19.x", "1.20.x" ]
go: [ "1.20.x", "1.21.0-rc.2" ]
runs-on: ${{ fromJSON(vars[format('UNIT_RUNNER_{0}', matrix.os)] || format('"{0}-latest"', matrix.os)) }}
name: Unit tests (${{ matrix.os}}, Go ${{ matrix.go }})
steps:

View file

@ -1,4 +1,6 @@
run:
skip-files:
- internal/handshake/cipher_suite.go
linters-settings:
depguard:
type: blacklist

View file

@ -220,7 +220,8 @@ quic-go always aims to support the latest two Go releases.
### Dependency on forked crypto/tls
Since the standard library didn't provide any QUIC APIs before the Go 1.21 release, we had to fork crypto/tls to add the required APIs ourselves: [qtls for Go 1.20](https://github.com/quic-go/qtls-go1-20) and [qtls for Go 1.19](https://github.com/quic-go/qtls-go1-19). This had led to a lot of pain in the Go ecosystem, and we're happy that we can rely on Go 1.21 going forward.
Since the standard library didn't provide any QUIC APIs before the Go 1.21 release, we had to fork crypto/tls to add the required APIs ourselves: [qtls for Go 1.20](https://github.com/quic-go/qtls-go1-20).
This had led to a lot of pain in the Go ecosystem, and we're happy that we can rely on Go 1.21 going forward.
## Contributing

View file

@ -8,6 +8,7 @@ coverage:
- http3/gzip_reader.go
- interop/
- internal/ackhandler/packet_linkedlist.go
- internal/handshake/cipher_suite.go
- internal/utils/byteinterval_linkedlist.go
- internal/utils/newconnectionid_linkedlist.go
- internal/utils/packetinterval_linkedlist.go

View file

@ -52,7 +52,7 @@ type streamManager interface {
}
type cryptoStreamHandler interface {
RunHandshake()
StartHandshake() error
ChangeConnectionID(protocol.ConnectionID)
SetLargest1RTTAcked(protocol.PacketNumber) error
SetHandshakeConfirmed()
@ -98,15 +98,15 @@ type connRunner interface {
type handshakeRunner struct {
onReceivedParams func(*wire.TransportParameters)
onError func(error)
onReceivedReadKeys func()
dropKeys func(protocol.EncryptionLevel)
onHandshakeComplete func()
}
func (r *handshakeRunner) OnReceivedParams(tp *wire.TransportParameters) { r.onReceivedParams(tp) }
func (r *handshakeRunner) OnError(e error) { r.onError(e) }
func (r *handshakeRunner) DropKeys(el protocol.EncryptionLevel) { r.dropKeys(el) }
func (r *handshakeRunner) OnHandshakeComplete() { r.onHandshakeComplete() }
func (r *handshakeRunner) OnReceivedReadKeys() { r.onReceivedReadKeys() }
type closeError struct {
err error
@ -329,14 +329,13 @@ var newConnection = func(
cs := handshake.NewCryptoSetupServer(
initialStream,
handshakeStream,
s.oneRTTStream,
clientDestConnID,
conn.LocalAddr(),
conn.RemoteAddr(),
params,
&handshakeRunner{
onReceivedParams: s.handleTransportParameters,
onError: s.closeLocal,
dropKeys: s.dropEncryptionLevel,
onReceivedReadKeys: s.receivedReadKeys,
onHandshakeComplete: func() {
runner.Retire(clientDestConnID)
close(s.handshakeCompleteChan)
@ -418,6 +417,7 @@ var newClientConnection = func(
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize)
initialStream := newCryptoStream()
handshakeStream := newCryptoStream()
oneRTTStream := newCryptoStream()
params := &wire.TransportParameters{
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
@ -448,14 +448,13 @@ var newClientConnection = func(
cs, clientHelloWritten := handshake.NewCryptoSetupClient(
initialStream,
handshakeStream,
oneRTTStream,
destConnID,
conn.LocalAddr(),
conn.RemoteAddr(),
params,
&handshakeRunner{
onReceivedParams: s.handleTransportParameters,
onError: s.closeLocal,
dropKeys: s.dropEncryptionLevel,
onReceivedReadKeys: s.receivedReadKeys,
onHandshakeComplete: func() { close(s.handshakeCompleteChan) },
},
tlsConf,
@ -467,7 +466,7 @@ var newClientConnection = func(
)
s.clientHelloWritten = clientHelloWritten
s.cryptoStreamHandler = cs
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, newCryptoStream())
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, oneRTTStream)
s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen)
s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, initialStream, handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective)
if len(tlsConf.ServerName) > 0 {
@ -530,11 +529,9 @@ func (s *connection) run() error {
s.timer = *newTimer()
handshaking := make(chan struct{})
go func() {
defer close(handshaking)
s.cryptoStreamHandler.RunHandshake()
}()
if err := s.cryptoStreamHandler.StartHandshake(); err != nil {
return err
}
go func() {
if err := s.sendQueue.Run(); err != nil {
s.destroyImpl(err)
@ -686,7 +683,6 @@ runLoop:
}
s.cryptoStreamHandler.Close()
<-handshaking
s.sendQueue.Close() // close the send queue before sending the CONNECTION_CLOSE
s.handleCloseError(&closeErr)
if e := (&errCloseForRecreating{}); !errors.As(closeErr.err, &e) && s.tracer != nil {
@ -717,7 +713,9 @@ func (s *connection) supportsDatagrams() bool {
func (s *connection) ConnectionState() ConnectionState {
s.connStateMutex.Lock()
defer s.connStateMutex.Unlock()
s.connState.TLS = s.cryptoStreamHandler.ConnectionState()
cs := s.cryptoStreamHandler.ConnectionState()
s.connState.TLS = cs.ConnectionState
s.connState.Used0RTT = cs.Used0RTT
return s.connState
}
@ -786,7 +784,7 @@ func (s *connection) handleHandshakeComplete() {
if err != nil {
s.closeLocal(err)
}
if ticket != nil {
if ticket != nil { // may be nil if session tickets are disabled via tls.Config.SessionTicketsDisabled
s.oneRTTStream.Write(ticket)
for s.oneRTTStream.HasData() {
s.queueControlFrame(s.oneRTTStream.PopCryptoFrame(protocol.MaxPostHandshakeCryptoFrameSize))
@ -1378,17 +1376,14 @@ func (s *connection) handleConnectionCloseFrame(frame *wire.ConnectionCloseFrame
}
func (s *connection) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error {
encLevelChanged, err := s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel)
if err != nil {
return err
return s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel)
}
if encLevelChanged {
func (s *connection) receivedReadKeys() {
// Queue all packets for decryption that have been undecryptable so far.
s.undecryptablePacketsToProcess = s.undecryptablePackets
s.undecryptablePackets = nil
}
return nil
}
func (s *connection) handleStreamFrame(frame *wire.StreamFrame) error {
str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID)
@ -1629,11 +1624,15 @@ func (s *connection) handleCloseError(closeErr *closeError) {
}
func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) {
s.sentPacketHandler.DropPackets(encLevel)
s.receivedPacketHandler.DropPackets(encLevel)
if s.tracer != nil {
s.tracer.DroppedEncryptionLevel(encLevel)
}
s.sentPacketHandler.DropPackets(encLevel)
s.receivedPacketHandler.DropPackets(encLevel)
if err := s.cryptoStreamManager.Drop(encLevel); err != nil {
s.closeLocal(err)
return
}
if encLevel == protocol.Encryption0RTT {
s.streamsMap.ResetFor0RTT()
if err := s.connFlowController.Reset(); err != nil {
@ -1817,6 +1816,9 @@ func (s *connection) sendPackets(now time.Time) error {
s.framer.QueueControlFrame(&wire.DataBlockedFrame{MaximumData: offset})
}
s.windowUpdateQueue.QueueAll()
if cf := s.cryptoStreamManager.GetPostHandshakeData(protocol.MaxPostHandshakeCryptoFrameSize); cf != nil {
s.queueControlFrame(cf)
}
if !s.handshakeConfirmed {
packet, err := s.packer.PackCoalescedPacket(false, s.mtuDiscoverer.CurrentSize(), s.version)

View file

@ -119,7 +119,7 @@ var _ = Describe("Connection", func() {
&protocol.DefaultConnectionIDGenerator{},
protocol.StatelessResetToken{},
populateServerConfig(&Config{DisablePathMTUDiscovery: true}),
nil, // tls.Config
&tls.Config{},
tokenGenerator,
false,
tracer,
@ -357,7 +357,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
Expect(conn.run()).To(MatchError(expectedErr))
}()
Expect(conn.handleFrame(&wire.ConnectionCloseFrame{
@ -385,7 +385,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
Expect(conn.run()).To(MatchError(testErr))
}()
ccf := &wire.ConnectionCloseFrame{
@ -432,7 +432,7 @@ var _ = Describe("Connection", func() {
runConn := func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
runErr <- conn.run()
}()
Eventually(areConnsRunning).Should(BeTrue())
@ -811,7 +811,7 @@ var _ = Describe("Connection", func() {
packer.EXPECT().PackConnectionClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
conn.run()
}()
expectReplaceWithClosed()
@ -853,7 +853,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
conn.run()
}()
Consistently(conn.Context().Done()).ShouldNot(BeClosed())
@ -888,7 +888,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
conn.run()
}()
Consistently(conn.Context().Done()).ShouldNot(BeClosed())
@ -913,7 +913,7 @@ var _ = Describe("Connection", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
err := conn.run()
Expect(err).To(HaveOccurred())
Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{}))
@ -937,7 +937,7 @@ var _ = Describe("Connection", func() {
runErr := make(chan error)
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
runErr <- conn.run()
}()
expectReplaceWithClosed()
@ -961,7 +961,7 @@ var _ = Describe("Connection", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
err := conn.run()
Expect(err).To(HaveOccurred())
Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{}))
@ -1197,7 +1197,7 @@ var _ = Describe("Connection", func() {
runConn := func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
conn.run()
close(connDone)
}()
@ -1415,7 +1415,7 @@ var _ = Describe("Connection", func() {
})
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
conn.run()
}()
conn.scheduleSending()
@ -1439,7 +1439,7 @@ var _ = Describe("Connection", func() {
})
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
conn.run()
}()
conn.scheduleSending()
@ -1463,7 +1463,7 @@ var _ = Describe("Connection", func() {
})
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
conn.run()
}()
conn.scheduleSending()
@ -1479,7 +1479,7 @@ var _ = Describe("Connection", func() {
sender.EXPECT().Send(gomock.Any(), gomock.Any())
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
conn.run()
}()
conn.scheduleSending()
@ -1496,7 +1496,7 @@ var _ = Describe("Connection", func() {
sender.EXPECT().Send(gomock.Any(), gomock.Any())
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
conn.run()
}()
conn.scheduleSending()
@ -1514,7 +1514,7 @@ var _ = Describe("Connection", func() {
sender.EXPECT().Send(gomock.Any(), gomock.Any())
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
conn.run()
}()
conn.scheduleSending()
@ -1540,7 +1540,7 @@ var _ = Describe("Connection", func() {
sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { written <- struct{}{} }).Times(2)
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
conn.run()
}()
conn.scheduleSending()
@ -1562,7 +1562,7 @@ var _ = Describe("Connection", func() {
sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { written <- struct{}{} }).Times(3)
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
conn.run()
}()
conn.scheduleSending()
@ -1580,7 +1580,7 @@ var _ = Describe("Connection", func() {
sender.EXPECT().Available().Return(available)
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
conn.run()
}()
conn.scheduleSending()
@ -1602,7 +1602,7 @@ var _ = Describe("Connection", func() {
sender.EXPECT().WouldBlock().AnyTimes()
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
conn.run()
}()
@ -1633,7 +1633,7 @@ var _ = Describe("Connection", func() {
sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { written <- struct{}{} })
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
conn.run()
}()
available := make(chan struct{}, 1)
@ -1664,7 +1664,7 @@ var _ = Describe("Connection", func() {
// don't EXPECT any calls to mconn.Write()
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
conn.run()
}()
conn.scheduleSending() // no packet will get sent
@ -1687,7 +1687,7 @@ var _ = Describe("Connection", func() {
packer.EXPECT().PackMTUProbePacket(ping, protocol.ByteCount(1234), conn.version).Return(shortHeaderPacket{PacketNumber: 1}, getPacketBuffer(), nil)
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
conn.run()
}()
conn.scheduleSending()
@ -1734,7 +1734,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
conn.run()
}()
// don't EXPECT any calls to mconn.Write()
@ -1768,7 +1768,7 @@ var _ = Describe("Connection", func() {
tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
conn.run()
}()
Eventually(written).Should(BeClosed())
@ -1832,7 +1832,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
conn.run()
}()
@ -1864,7 +1864,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
<-finishHandshake
cryptoSetup.EXPECT().RunHandshake()
cryptoSetup.EXPECT().StartHandshake()
cryptoSetup.EXPECT().SetHandshakeConfirmed()
cryptoSetup.EXPECT().GetSessionTicket()
close(conn.handshakeCompleteChan)
@ -1894,7 +1894,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
<-finishHandshake
cryptoSetup.EXPECT().RunHandshake()
cryptoSetup.EXPECT().StartHandshake()
cryptoSetup.EXPECT().SetHandshakeConfirmed()
cryptoSetup.EXPECT().GetSessionTicket().Return(make([]byte, size), nil)
close(conn.handshakeCompleteChan)
@ -1941,7 +1941,7 @@ var _ = Describe("Connection", func() {
tracer.EXPECT().Close()
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake()
cryptoSetup.EXPECT().StartHandshake()
conn.run()
}()
handshakeCtx := conn.HandshakeComplete()
@ -1974,7 +1974,7 @@ var _ = Describe("Connection", func() {
packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes()
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake()
cryptoSetup.EXPECT().StartHandshake()
cryptoSetup.EXPECT().SetHandshakeConfirmed()
cryptoSetup.EXPECT().GetSessionTicket()
mconn.EXPECT().Write(gomock.Any(), gomock.Any())
@ -1997,7 +1997,7 @@ var _ = Describe("Connection", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
Expect(conn.run()).To(Succeed())
close(done)
}()
@ -2017,7 +2017,7 @@ var _ = Describe("Connection", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
err := conn.run()
Expect(err).To(MatchError(&qerr.ApplicationError{
ErrorCode: 0x1337,
@ -2069,7 +2069,7 @@ var _ = Describe("Connection", func() {
runConn := func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
conn.run()
}()
}
@ -2171,7 +2171,7 @@ var _ = Describe("Connection", func() {
)
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
err := conn.run()
nerr, ok := err.(net.Error)
Expect(ok).To(BeTrue())
@ -2196,7 +2196,7 @@ var _ = Describe("Connection", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
err := conn.run()
nerr, ok := err.(net.Error)
Expect(ok).To(BeTrue())
@ -2229,7 +2229,7 @@ var _ = Describe("Connection", func() {
// and not on the last network activity
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
conn.run()
}()
Consistently(conn.Context().Done()).ShouldNot(BeClosed())
@ -2256,7 +2256,7 @@ var _ = Describe("Connection", func() {
conn.handshakeComplete = false
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1)
err := conn.run()
nerr, ok := err.(net.Error)
@ -2285,7 +2285,7 @@ var _ = Describe("Connection", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1)
cryptoSetup.EXPECT().SetHandshakeConfirmed().MaxTimes(1)
close(conn.handshakeCompleteChan)
@ -2305,7 +2305,7 @@ var _ = Describe("Connection", func() {
conn.idleTimeout = 30 * time.Second
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
conn.run()
}()
Consistently(conn.Context().Done()).ShouldNot(BeClosed())
@ -2336,7 +2336,7 @@ var _ = Describe("Connection", func() {
pto := conn.rttStats.PTO(true)
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1)
cryptoSetup.EXPECT().SetHandshakeConfirmed().MaxTimes(1)
close(conn.handshakeCompleteChan)
@ -2508,7 +2508,7 @@ var _ = Describe("Client Connection", func() {
conn.unpacker = unpacker
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
conn.run()
}()
newConnID := protocol.ParseConnectionID([]byte{1, 3, 3, 7, 1, 3, 3, 7})
@ -2588,7 +2588,7 @@ var _ = Describe("Client Connection", func() {
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
running := make(chan struct{})
cryptoSetup.EXPECT().RunHandshake().Do(func() {
cryptoSetup.EXPECT().StartHandshake().Do(func() {
close(running)
conn.closeLocal(errors.New("early error"))
})
@ -2641,7 +2641,7 @@ var _ = Describe("Client Connection", func() {
errChan := make(chan error, 1)
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
errChan <- conn.run()
}()
connRunner.EXPECT().Remove(srcConnID)
@ -2666,7 +2666,7 @@ var _ = Describe("Client Connection", func() {
errChan := make(chan error, 1)
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
errChan <- conn.run()
}()
connRunner.EXPECT().Remove(srcConnID).MaxTimes(1)
@ -2774,7 +2774,7 @@ var _ = Describe("Client Connection", func() {
closed = false
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
errChan <- conn.run()
close(errChan)
}()

View file

@ -71,17 +71,9 @@ func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error {
// GetCryptoData retrieves data that was received in CRYPTO frames
func (s *cryptoStreamImpl) GetCryptoData() []byte {
if len(s.msgBuf) < 4 {
return nil
}
msgLen := 4 + int(s.msgBuf[1])<<16 + int(s.msgBuf[2])<<8 + int(s.msgBuf[3])
if len(s.msgBuf) < msgLen {
return nil
}
msg := make([]byte, msgLen)
copy(msg, s.msgBuf[:msgLen])
s.msgBuf = s.msgBuf[msgLen:]
return msg
b := s.msgBuf
s.msgBuf = nil
return b
}
func (s *cryptoStreamImpl) Finish() error {

View file

@ -8,7 +8,7 @@ import (
)
type cryptoDataHandler interface {
HandleMessage([]byte, protocol.EncryptionLevel) bool
HandleMessage([]byte, protocol.EncryptionLevel) error
}
type cryptoStreamManager struct {
@ -33,7 +33,7 @@ func newCryptoStreamManager(
}
}
func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) (bool /* encryption level changed */, error) {
func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error {
var str cryptoStream
//nolint:exhaustive // CRYPTO frames cannot be sent in 0-RTT packets.
switch encLevel {
@ -44,18 +44,39 @@ func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLeve
case protocol.Encryption1RTT:
str = m.oneRTTStream
default:
return false, fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel)
return fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel)
}
if err := str.HandleCryptoFrame(frame); err != nil {
return false, err
return err
}
for {
data := str.GetCryptoData()
if data == nil {
return false, nil
return nil
}
if encLevelFinished := m.cryptoHandler.HandleMessage(data, encLevel); encLevelFinished {
return true, str.Finish()
if err := m.cryptoHandler.HandleMessage(data, encLevel); err != nil {
return err
}
}
}
func (m *cryptoStreamManager) GetPostHandshakeData(maxSize protocol.ByteCount) *wire.CryptoFrame {
if !m.oneRTTStream.HasData() {
return nil
}
return m.oneRTTStream.PopCryptoFrame(maxSize)
}
func (m *cryptoStreamManager) Drop(encLevel protocol.EncryptionLevel) error {
//nolint:exhaustive // 1-RTT keys should never get dropped.
switch encLevel {
case protocol.EncryptionInitial:
return m.initialStream.Finish()
case protocol.EncryptionHandshake:
return m.handshakeStream.Finish()
case protocol.Encryption0RTT:
return nil
default:
panic(fmt.Sprintf("dropped unexpected encryption level: %s", encLevel))
}
}

View file

@ -1,12 +1,9 @@
package quic
import (
"errors"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
@ -35,9 +32,7 @@ var _ = Describe("Crypto Stream Manager", func() {
initialStream.EXPECT().GetCryptoData().Return([]byte("foobar"))
initialStream.EXPECT().GetCryptoData()
cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionInitial)
encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionInitial)
Expect(err).ToNot(HaveOccurred())
Expect(encLevelChanged).To(BeFalse())
Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionInitial)).To(Succeed())
})
It("passes messages to the handshake stream", func() {
@ -46,9 +41,7 @@ var _ = Describe("Crypto Stream Manager", func() {
handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar"))
handshakeStream.EXPECT().GetCryptoData()
cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake)
encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)
Expect(err).ToNot(HaveOccurred())
Expect(encLevelChanged).To(BeFalse())
Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed())
})
It("passes messages to the 1-RTT stream", func() {
@ -57,9 +50,7 @@ var _ = Describe("Crypto Stream Manager", func() {
oneRTTStream.EXPECT().GetCryptoData().Return([]byte("foobar"))
oneRTTStream.EXPECT().GetCryptoData()
cs.EXPECT().HandleMessage([]byte("foobar"), protocol.Encryption1RTT)
encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.Encryption1RTT)
Expect(err).ToNot(HaveOccurred())
Expect(encLevelChanged).To(BeFalse())
Expect(csm.HandleCryptoFrame(cf, protocol.Encryption1RTT)).To(Succeed())
})
It("doesn't call the message handler, if there's no message", func() {
@ -67,9 +58,7 @@ var _ = Describe("Crypto Stream Manager", func() {
handshakeStream.EXPECT().HandleCryptoFrame(cf)
handshakeStream.EXPECT().GetCryptoData() // don't return any data to handle
// don't EXPECT any calls to HandleMessage()
encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)
Expect(err).ToNot(HaveOccurred())
Expect(encLevelChanged).To(BeFalse())
Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed())
})
It("processes all messages", func() {
@ -80,40 +69,26 @@ var _ = Describe("Crypto Stream Manager", func() {
handshakeStream.EXPECT().GetCryptoData()
cs.EXPECT().HandleMessage([]byte("foo"), protocol.EncryptionHandshake)
cs.EXPECT().HandleMessage([]byte("bar"), protocol.EncryptionHandshake)
encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)
Expect(err).ToNot(HaveOccurred())
Expect(encLevelChanged).To(BeFalse())
})
It("finishes the crypto stream, when the crypto setup is done with this encryption level", func() {
cf := &wire.CryptoFrame{Data: []byte("foobar")}
gomock.InOrder(
handshakeStream.EXPECT().HandleCryptoFrame(cf),
handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar")),
cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake).Return(true),
handshakeStream.EXPECT().Finish(),
)
encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)
Expect(err).ToNot(HaveOccurred())
Expect(encLevelChanged).To(BeTrue())
})
It("returns errors that occur when finishing a stream", func() {
testErr := errors.New("test error")
cf := &wire.CryptoFrame{Data: []byte("foobar")}
gomock.InOrder(
handshakeStream.EXPECT().HandleCryptoFrame(cf),
handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar")),
cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake).Return(true),
handshakeStream.EXPECT().Finish().Return(testErr),
)
_, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)
Expect(err).To(MatchError(err))
Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed())
})
It("errors for unknown encryption levels", func() {
_, err := csm.HandleCryptoFrame(&wire.CryptoFrame{}, 42)
err := csm.HandleCryptoFrame(&wire.CryptoFrame{}, 42)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("received CRYPTO frame with unexpected encryption level"))
})
It("drops Initial", func() {
initialStream.EXPECT().Finish()
Expect(csm.Drop(protocol.EncryptionInitial)).To(Succeed())
})
It("drops Handshake", func() {
handshakeStream.EXPECT().Finish()
Expect(csm.Drop(protocol.EncryptionHandshake)).To(Succeed())
})
It("no-ops when dropping 0-RTT", func() {
Expect(csm.Drop(protocol.Encryption0RTT)).To(Succeed())
})
})

View file

@ -1,7 +1,6 @@
package quic
import (
"crypto/rand"
"fmt"
"github.com/quic-go/quic-go/internal/protocol"
@ -12,16 +11,6 @@ import (
. "github.com/onsi/gomega"
)
func createHandshakeMessage(len int) []byte {
msg := make([]byte, 4+len)
rand.Read(msg[:1]) // random message type
msg[1] = uint8(len >> 16)
msg[2] = uint8(len >> 8)
msg[3] = uint8(len)
rand.Read(msg[4:])
return msg
}
var _ = Describe("Crypto Stream", func() {
var str cryptoStream
@ -31,21 +20,11 @@ var _ = Describe("Crypto Stream", func() {
Context("handling incoming data", func() {
It("handles in-order CRYPTO frames", func() {
msg := createHandshakeMessage(6)
err := str.HandleCryptoFrame(&wire.CryptoFrame{Data: msg})
Expect(err).ToNot(HaveOccurred())
Expect(str.GetCryptoData()).To(Equal(msg))
Expect(str.HandleCryptoFrame(&wire.CryptoFrame{Data: []byte("foo")})).To(Succeed())
Expect(str.GetCryptoData()).To(Equal([]byte("foo")))
Expect(str.GetCryptoData()).To(BeNil())
})
It("handles multiple messages in one CRYPTO frame", func() {
msg1 := createHandshakeMessage(6)
msg2 := createHandshakeMessage(10)
msg := append(append([]byte{}, msg1...), msg2...)
err := str.HandleCryptoFrame(&wire.CryptoFrame{Data: msg})
Expect(err).ToNot(HaveOccurred())
Expect(str.GetCryptoData()).To(Equal(msg1))
Expect(str.GetCryptoData()).To(Equal(msg2))
Expect(str.HandleCryptoFrame(&wire.CryptoFrame{Data: []byte("bar"), Offset: 3})).To(Succeed())
Expect(str.GetCryptoData()).To(Equal([]byte("bar")))
Expect(str.GetCryptoData()).To(BeNil())
})
@ -59,42 +38,17 @@ var _ = Describe("Crypto Stream", func() {
}))
})
It("handles messages split over multiple CRYPTO frames", func() {
msg := createHandshakeMessage(6)
err := str.HandleCryptoFrame(&wire.CryptoFrame{
Data: msg[:4],
})
Expect(err).ToNot(HaveOccurred())
Expect(str.GetCryptoData()).To(BeNil())
err = str.HandleCryptoFrame(&wire.CryptoFrame{
Offset: 4,
Data: msg[4:],
})
Expect(err).ToNot(HaveOccurred())
Expect(str.GetCryptoData()).To(Equal(msg))
Expect(str.GetCryptoData()).To(BeNil())
})
It("handles out-of-order CRYPTO frames", func() {
msg := createHandshakeMessage(6)
err := str.HandleCryptoFrame(&wire.CryptoFrame{
Offset: 4,
Data: msg[4:],
})
Expect(err).ToNot(HaveOccurred())
Expect(str.GetCryptoData()).To(BeNil())
err = str.HandleCryptoFrame(&wire.CryptoFrame{
Data: msg[:4],
})
Expect(err).ToNot(HaveOccurred())
Expect(str.GetCryptoData()).To(Equal(msg))
Expect(str.HandleCryptoFrame(&wire.CryptoFrame{Offset: 3, Data: []byte("bar")})).To(Succeed())
Expect(str.HandleCryptoFrame(&wire.CryptoFrame{Data: []byte("foo")})).To(Succeed())
Expect(str.GetCryptoData()).To(Equal([]byte("foobar")))
Expect(str.GetCryptoData()).To(BeNil())
})
Context("finishing", func() {
It("errors if there's still data to read after finishing", func() {
Expect(str.HandleCryptoFrame(&wire.CryptoFrame{
Data: createHandshakeMessage(5),
Data: []byte("foobar"),
Offset: 10,
})).To(Succeed())
Expect(str.Finish()).To(MatchError(&qerr.TransportError{
@ -120,7 +74,7 @@ var _ = Describe("Crypto Stream", func() {
It("rejects new crypto data after finishing", func() {
Expect(str.Finish()).To(Succeed())
Expect(str.HandleCryptoFrame(&wire.CryptoFrame{
Data: createHandshakeMessage(5),
Data: []byte("foo"),
})).To(MatchError(&qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: "received crypto data after change of encryption level",
@ -128,15 +82,14 @@ var _ = Describe("Crypto Stream", func() {
})
It("ignores crypto data below the maximum offset received before finishing", func() {
msg := createHandshakeMessage(15)
Expect(str.HandleCryptoFrame(&wire.CryptoFrame{
Data: msg,
Data: []byte("foobar"),
})).To(Succeed())
Expect(str.GetCryptoData()).To(Equal(msg))
Expect(str.GetCryptoData()).To(Equal([]byte("foobar")))
Expect(str.Finish()).To(Succeed())
Expect(str.HandleCryptoFrame(&wire.CryptoFrame{
Offset: protocol.ByteCount(len(msg) - 6),
Data: []byte("foobar"),
Offset: 2,
Data: []byte("foo"),
})).To(Succeed())
})
})

View file

@ -43,26 +43,24 @@ func initStreams() (chan chunk, *stream /* initial */, *stream /* handshake */)
type handshakeRunner interface {
OnReceivedParams(*wire.TransportParameters)
OnHandshakeComplete()
OnError(error)
OnReceivedReadKeys()
DropKeys(protocol.EncryptionLevel)
}
type runner struct {
client, server *handshake.CryptoSetup
handshakeComplete chan<- struct{}
}
var _ handshakeRunner = &runner{}
func newRunner(client, server *handshake.CryptoSetup) *runner {
return &runner{client: client, server: server}
func newRunner(handshakeComplete chan<- struct{}) *runner {
return &runner{handshakeComplete: handshakeComplete}
}
func (r *runner) OnReceivedParams(*wire.TransportParameters) {}
func (r *runner) OnHandshakeComplete() {}
func (r *runner) OnError(err error) {
(*r.client).Close()
(*r.server).Close()
log.Fatal("runner error:", err)
func (r *runner) OnReceivedReadKeys() {}
func (r *runner) OnHandshakeComplete() {
close(r.handshakeComplete)
}
func (r *runner) DropKeys(protocol.EncryptionLevel) {}
@ -71,16 +69,16 @@ const alpn = "fuzz"
func main() {
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
var client, server handshake.CryptoSetup
runner := newRunner(&client, &server)
clientHandshakeCompleted := make(chan struct{})
client, _ = handshake.NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
nil,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{ActiveConnectionIDLimit: 2},
runner,
newRunner(clientHandshakeCompleted),
&tls.Config{
MinVersion: tls.VersionTLS13,
ServerName: "localhost",
NextProtos: []string{alpn},
RootCAs: testdata.GetRootCA(),
@ -96,14 +94,14 @@ func main() {
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
config := testdata.GetTLSConfig()
config.NextProtos = []string{alpn}
serverHandshakeCompleted := make(chan struct{})
server = handshake.NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
nil,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{ActiveConnectionIDLimit: 2},
runner,
newRunner(serverHandshakeCompleted),
config,
false,
utils.NewRTTStats(),
@ -112,17 +110,13 @@ func main() {
protocol.Version1,
)
serverHandshakeCompleted := make(chan struct{})
go func() {
defer close(serverHandshakeCompleted)
server.RunHandshake()
}()
if err := client.StartHandshake(); err != nil {
log.Fatal(err)
}
clientHandshakeCompleted := make(chan struct{})
go func() {
defer close(clientHandshakeCompleted)
client.RunHandshake()
}()
if err := server.StartHandshake(); err != nil {
log.Fatal(err)
}
done := make(chan struct{})
go func() {
@ -137,10 +131,14 @@ messageLoop:
select {
case c := <-cChunkChan:
messages = append(messages, c.data)
server.HandleMessage(c.data, c.encLevel)
if err := server.HandleMessage(c.data, c.encLevel); err != nil {
log.Fatal(err)
}
case c := <-sChunkChan:
messages = append(messages, c.data)
client.HandleMessage(c.data, c.encLevel)
if err := client.HandleMessage(c.data, c.encLevel); err != nil {
log.Fatal(err)
}
case <-done:
break messageLoop
}

View file

@ -11,7 +11,6 @@ import (
"log"
"math"
mrand "math/rand"
"sync"
"time"
"github.com/quic-go/quic-go/fuzzing/internal/helper"
@ -157,39 +156,24 @@ func initStreams() (chan chunk, *stream /* initial */, *stream /* handshake */)
type handshakeRunner interface {
OnReceivedParams(*wire.TransportParameters)
OnHandshakeComplete()
OnError(error)
OnReceivedReadKeys()
DropKeys(protocol.EncryptionLevel)
}
type runner struct {
sync.Mutex
errored bool
client, server *handshake.CryptoSetup
handshakeComplete chan<- struct{}
}
var _ handshakeRunner = &runner{}
func newRunner(client, server *handshake.CryptoSetup) *runner {
return &runner{client: client, server: server}
func newRunner(handshakeComplete chan<- struct{}) *runner {
return &runner{handshakeComplete: handshakeComplete}
}
func (r *runner) OnReceivedParams(*wire.TransportParameters) {}
func (r *runner) OnHandshakeComplete() {}
func (r *runner) OnError(err error) {
r.Lock()
defer r.Unlock()
if r.errored {
return
}
r.errored = true
(*r.client).Close()
(*r.server).Close()
}
func (r *runner) Errored() bool {
r.Lock()
defer r.Unlock()
return r.errored
func (r *runner) OnReceivedReadKeys() {}
func (r *runner) OnHandshakeComplete() {
close(r.handshakeComplete)
}
func (r *runner) DropKeys(protocol.EncryptionLevel) {}
@ -270,6 +254,7 @@ func Fuzz(data []byte) int {
}
clientConf := &tls.Config{
MinVersion: tls.VersionTLS13,
ServerName: "localhost",
NextProtos: []string{alpn},
RootCAs: certPool,
@ -287,6 +272,7 @@ func Fuzz(data []byte) int {
func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.Config, data []byte) int {
serverConf := &tls.Config{
MinVersion: tls.VersionTLS13,
Certificates: []tls.Certificate{*cert},
NextProtos: []string{alpn},
SessionTicketKey: sessionTicketKey,
@ -373,15 +359,14 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
var client, server handshake.CryptoSetup
runner := newRunner(&client, &server)
clientHandshakeCompleted := make(chan struct{})
client, _ = handshake.NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
nil,
protocol.ConnectionID{},
nil,
nil,
clientTP,
runner,
newRunner(clientHandshakeCompleted),
clientConf,
enable0RTTClient,
utils.NewRTTStats(),
@ -390,15 +375,15 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.
protocol.Version1,
)
serverHandshakeCompleted := make(chan struct{})
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
server = handshake.NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
nil,
protocol.ConnectionID{},
nil,
nil,
serverTP,
runner,
newRunner(serverHandshakeCompleted),
serverConf,
enable0RTTServer,
utils.NewRTTStats(),
@ -411,17 +396,13 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.
return -1
}
serverHandshakeCompleted := make(chan struct{})
go func() {
defer close(serverHandshakeCompleted)
server.RunHandshake()
}()
if err := client.StartHandshake(); err != nil {
log.Fatal(err)
}
clientHandshakeCompleted := make(chan struct{})
go func() {
defer close(clientHandshakeCompleted)
client.RunHandshake()
}()
if err := server.StartHandshake(); err != nil {
log.Fatal(err)
}
done := make(chan struct{})
go func() {
@ -441,7 +422,9 @@ messageLoop:
b = data
encLevel = maxEncLevel(server, messageToReplaceEncLevel)
}
server.HandleMessage(b, encLevel)
if err := server.HandleMessage(b, encLevel); err != nil {
break messageLoop
}
case c := <-sChunkChan:
b := c.data
encLevel := c.encLevel
@ -450,11 +433,10 @@ messageLoop:
b = data
encLevel = maxEncLevel(client, messageToReplaceEncLevel)
}
client.HandleMessage(b, encLevel)
case <-done: // test done
if err := client.HandleMessage(b, encLevel); err != nil {
break messageLoop
}
if runner.Errored() {
case <-done: // test done
break messageLoop
}
}
@ -462,9 +444,6 @@ messageLoop:
<-done
_ = client.ConnectionState()
_ = server.ConnectionState()
if runner.Errored() {
return 0
}
sealer, err := client.Get1RTTSealer()
if err != nil {

3
go.mod
View file

@ -8,8 +8,7 @@ require (
github.com/onsi/ginkgo/v2 v2.9.5
github.com/onsi/gomega v1.27.6
github.com/quic-go/qpack v0.4.0
github.com/quic-go/qtls-go1-19 v0.3.2
github.com/quic-go/qtls-go1-20 v0.2.2
github.com/quic-go/qtls-go1-20 v0.3.0
golang.org/x/crypto v0.4.0
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db
golang.org/x/net v0.10.0

6
go.sum
View file

@ -90,10 +90,8 @@ github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7q
github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U=
github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI=
github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E=
github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM=
github.com/quic-go/qtls-go1-20 v0.3.0 h1:NrCXmDl8BddZwO67vlvEpBTwT89bJfKYygxv4HQvuDk=
github.com/quic-go/qtls-go1-20 v0.3.0/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k=
github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=
github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY=

View file

@ -15,7 +15,6 @@ import (
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qtls"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/quicvarint"
@ -402,7 +401,7 @@ func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str qui
return nil, newConnError(ErrCodeGeneralProtocolError, err)
}
connState := qtls.ToTLSConnectionState(conn.ConnectionState().TLS)
connState := conn.ConnectionState().TLS
res := &http.Response{
Proto: "HTTP/3.0",
ProtoMajor: 3,

View file

@ -577,7 +577,7 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q
return newStreamError(ErrCodeGeneralProtocolError, err)
}
connState := conn.ConnectionState().TLS.ConnectionState
connState := conn.ConnectionState().TLS
req.TLS = &connState
req.RemoteAddr = conn.RemoteAddr().String()
body := newRequestBody(newStream(str, onFrameError))

View file

@ -926,7 +926,7 @@ var _ = Describe("Server", func() {
c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
Expect(err).ToNot(HaveOccurred())
defer c.CloseWithError(0, "")
Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3))
Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3))
})
It("sets the GetConfigForClient callback if no tls.Config is given", func() {
@ -954,7 +954,7 @@ var _ = Describe("Server", func() {
c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
Expect(err).ToNot(HaveOccurred())
defer c.CloseWithError(0, "")
Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3))
Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3))
})
It("works if GetConfigForClient returns a nil tls.Config", func() {
@ -967,7 +967,7 @@ var _ = Describe("Server", func() {
c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
Expect(err).ToNot(HaveOccurred())
defer c.CloseWithError(0, "")
Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3))
Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3))
})
It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient, if it returns a static tls.Config", func() {
@ -985,7 +985,7 @@ var _ = Describe("Server", func() {
c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
Expect(err).ToNot(HaveOccurred())
defer c.CloseWithError(0, "")
Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3))
Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3))
// check that the original config was not modified
Expect(tlsClientConf.NextProtos).To(Equal([]string{"foo", "bar"}))
})

View file

@ -136,10 +136,8 @@ github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7q
github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U=
github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI=
github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E=
github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM=
github.com/quic-go/qtls-go1-20 v0.3.0 h1:NrCXmDl8BddZwO67vlvEpBTwT89bJfKYygxv4HQvuDk=
github.com/quic-go/qtls-go1-20 v0.3.0/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k=
github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=
github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY=
@ -185,7 +183,6 @@ golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnf
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190313024323-a1f597ede03a/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=

View file

@ -198,7 +198,10 @@ var _ = Describe("Handshake tests", func() {
var transportErr *quic.TransportError
Expect(errors.As(err, &transportErr)).To(BeTrue())
Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue())
Expect(transportErr.Error()).To(ContainSubstring("tls: bad certificate"))
Expect(transportErr.Error()).To(Or(
ContainSubstring("tls: certificate required"),
ContainSubstring("tls: bad certificate"),
))
})
It("uses the ServerName in the tls.Config", func() {

View file

@ -5,7 +5,7 @@ import (
"crypto/tls"
"fmt"
"net"
"sync"
"time"
"github.com/quic-go/quic-go"
@ -14,16 +14,15 @@ import (
)
type clientSessionCache struct {
mutex sync.Mutex
cache map[string]*tls.ClientSessionState
cache tls.ClientSessionCache
gets chan<- string
puts chan<- string
}
func newClientSessionCache(gets, puts chan<- string) *clientSessionCache {
func newClientSessionCache(cache tls.ClientSessionCache, gets, puts chan<- string) *clientSessionCache {
return &clientSessionCache{
cache: make(map[string]*tls.ClientSessionState),
cache: cache,
gets: gets,
puts: puts,
}
@ -32,29 +31,25 @@ func newClientSessionCache(gets, puts chan<- string) *clientSessionCache {
var _ tls.ClientSessionCache = &clientSessionCache{}
func (c *clientSessionCache) Get(sessionKey string) (*tls.ClientSessionState, bool) {
session, ok := c.cache.Get(sessionKey)
c.gets <- sessionKey
c.mutex.Lock()
session, ok := c.cache[sessionKey]
c.mutex.Unlock()
return session, ok
}
func (c *clientSessionCache) Put(sessionKey string, cs *tls.ClientSessionState) {
c.cache.Put(sessionKey, cs)
c.puts <- sessionKey
c.mutex.Lock()
c.cache[sessionKey] = cs
c.mutex.Unlock()
}
var _ = Describe("TLS session resumption", func() {
It("uses session resumption", func() {
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), nil)
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
defer server.Close()
gets := make(chan string, 100)
puts := make(chan string, 100)
cache := newClientSessionCache(gets, puts)
cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
tlsConf := getTLSClientConfig()
tlsConf.ClientSessionCache = cache
conn, err := quic.DialAddr(
@ -96,7 +91,7 @@ var _ = Describe("TLS session resumption", func() {
gets := make(chan string, 100)
puts := make(chan string, 100)
cache := newClientSessionCache(gets, puts)
cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
tlsConf := getTLSClientConfig()
tlsConf.ClientSessionCache = cache
conn, err := quic.DialAddr(
@ -109,7 +104,9 @@ var _ = Describe("TLS session resumption", func() {
Consistently(puts).ShouldNot(Receive())
Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())
serverConn, err := server.Accept(context.Background())
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
serverConn, err := server.Accept(ctx)
Expect(err).ToNot(HaveOccurred())
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())

View file

@ -0,0 +1,804 @@
//go:build !go1.21
package self_test
import (
"context"
"crypto/tls"
"fmt"
"io"
mrand "math/rand"
"net"
"sync"
"sync/atomic"
"time"
"github.com/quic-go/quic-go"
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("0-RTT", func() {
rtt := scaleDuration(5 * time.Millisecond)
runCountingProxy := func(serverPort int) (*quicproxy.QuicProxy, *uint32) {
var num0RTTPackets uint32 // to be used as an atomic
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration {
for len(data) > 0 {
if !wire.IsLongHeaderPacket(data[0]) {
break
}
hdr, _, rest, err := wire.ParsePacket(data)
Expect(err).ToNot(HaveOccurred())
if hdr.Type == protocol.PacketType0RTT {
atomic.AddUint32(&num0RTTPackets, 1)
break
}
data = rest
}
return rtt / 2
},
})
Expect(err).ToNot(HaveOccurred())
return proxy, &num0RTTPackets
}
dialAndReceiveSessionTicket := func(serverConf *quic.Config) (*tls.Config, *tls.Config) {
tlsConf := getTLSConfig()
if serverConf == nil {
serverConf = getQuicConfig(nil)
}
serverConf.Allow0RTT = true
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
serverConf,
)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration { return rtt / 2 },
})
Expect(err).ToNot(HaveOccurred())
defer proxy.Close()
// dial the first connection in order to receive a session ticket
done := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(done)
conn, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
<-conn.Context().Done()
}()
clientConf := getTLSClientConfig()
gets := make(chan string, 100)
puts := make(chan string, 100)
clientConf.ClientSessionCache = newClientSessionCache(tls.NewLRUClientSessionCache(100), gets, puts)
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
clientConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
Eventually(puts).Should(Receive())
// received the session ticket. We're done here.
Expect(conn.CloseWithError(0, "")).To(Succeed())
Eventually(done).Should(BeClosed())
return tlsConf, clientConf
}
transfer0RTTData := func(
ln *quic.EarlyListener,
proxyPort int,
connIDLen int,
clientTLSConf *tls.Config,
clientConf *quic.Config,
testdata []byte, // data to transfer
) {
// accept the second connection, and receive the data sent in 0-RTT
done := make(chan struct{})
go func() {
defer GinkgoRecover()
conn, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
str, err := conn.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
data, err := io.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal(testdata))
Expect(str.Close()).To(Succeed())
Expect(conn.ConnectionState().Used0RTT).To(BeTrue())
<-conn.Context().Done()
close(done)
}()
if clientConf == nil {
clientConf = getQuicConfig(nil)
}
var conn quic.EarlyConnection
if connIDLen == 0 {
var err error
conn, err = quic.DialAddrEarly(
context.Background(),
fmt.Sprintf("localhost:%d", proxyPort),
clientTLSConf,
clientConf,
)
Expect(err).ToNot(HaveOccurred())
} else {
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
udpConn, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
defer udpConn.Close()
tr := &quic.Transport{
Conn: udpConn,
ConnectionIDLength: connIDLen,
}
defer tr.Close()
conn, err = tr.DialEarly(
context.Background(),
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: proxyPort},
clientTLSConf,
clientConf,
)
Expect(err).ToNot(HaveOccurred())
}
defer conn.CloseWithError(0, "")
str, err := conn.OpenStream()
Expect(err).ToNot(HaveOccurred())
_, err = str.Write(testdata)
Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed())
<-conn.HandshakeComplete()
Expect(conn.ConnectionState().Used0RTT).To(BeTrue())
io.ReadAll(str) // wait for the EOF from the server to arrive before closing the conn
conn.CloseWithError(0, "")
Eventually(done).Should(BeClosed())
Eventually(conn.Context().Done()).Should(BeClosed())
}
check0RTTRejected := func(
ln *quic.EarlyListener,
proxyPort int,
clientConf *tls.Config,
) {
conn, err := quic.DialAddrEarly(
context.Background(),
fmt.Sprintf("localhost:%d", proxyPort),
clientConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
str, err := conn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
_, err = str.Write(make([]byte, 3000))
Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed())
Expect(conn.ConnectionState().Used0RTT).To(BeFalse())
// make sure the server doesn't process the data
ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(50*time.Millisecond))
defer cancel()
serverConn, err := ln.Accept(ctx)
Expect(err).ToNot(HaveOccurred())
Expect(serverConn.ConnectionState().Used0RTT).To(BeFalse())
_, err = serverConn.AcceptUniStream(ctx)
Expect(err).To(Equal(context.DeadlineExceeded))
Expect(serverConn.CloseWithError(0, "")).To(Succeed())
Eventually(conn.Context().Done()).Should(BeClosed())
}
// can be used to extract 0-RTT from a packetTracer
get0RTTPackets := func(packets []packet) []protocol.PacketNumber {
var zeroRTTPackets []protocol.PacketNumber
for _, p := range packets {
if p.hdr.Type == protocol.PacketType0RTT {
zeroRTTPackets = append(zeroRTTPackets, p.hdr.PacketNumber)
}
}
return zeroRTTPackets
}
for _, l := range []int{0, 15} {
connIDLen := l
It(fmt.Sprintf("transfers 0-RTT data, with %d byte connection IDs", connIDLen), func() {
tlsConf, clientTLSConf := dialAndReceiveSessionTicket(nil)
tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Allow0RTT: true,
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
defer proxy.Close()
transfer0RTTData(
ln,
proxy.LocalPort(),
connIDLen,
clientTLSConf,
getQuicConfig(nil),
PRData,
)
var numNewConnIDs int
for _, p := range tracer.getRcvdLongHeaderPackets() {
for _, f := range p.frames {
if _, ok := f.(*logging.NewConnectionIDFrame); ok {
numNewConnIDs++
}
}
}
if connIDLen == 0 {
Expect(numNewConnIDs).To(BeZero())
} else {
Expect(numNewConnIDs).ToNot(BeZero())
}
num0RTT := atomic.LoadUint32(num0RTTPackets)
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero())
zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets())
Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10))
Expect(zeroRTTPackets).To(ContainElement(protocol.PacketNumber(0)))
})
}
// Test that data intended to be sent with 1-RTT protection is not sent in 0-RTT packets.
It("waits for a connection until the handshake is done", func() {
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
zeroRTTData := GeneratePRData(5 << 10)
oneRTTData := PRData
tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Allow0RTT: true,
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
// now accept the second connection, and receive the 0-RTT data
go func() {
defer GinkgoRecover()
conn, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
str, err := conn.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred())
data, err := io.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal(zeroRTTData))
str, err = conn.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred())
data, err = io.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal(oneRTTData))
Expect(conn.CloseWithError(0, "")).To(Succeed())
}()
proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
defer proxy.Close()
conn, err := quic.DialAddrEarly(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
clientConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
firstStr, err := conn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
_, err = firstStr.Write(zeroRTTData)
Expect(err).ToNot(HaveOccurred())
Expect(firstStr.Close()).To(Succeed())
// wait for the handshake to complete
Eventually(conn.HandshakeComplete()).Should(BeClosed())
str, err := conn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
_, err = str.Write(PRData)
Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed())
<-conn.Context().Done()
// check that 0-RTT packets only contain STREAM frames for the first stream
var num0RTT int
for _, p := range tracer.getRcvdLongHeaderPackets() {
if p.hdr.Header.Type != protocol.PacketType0RTT {
continue
}
for _, f := range p.frames {
sf, ok := f.(*logging.StreamFrame)
if !ok {
continue
}
num0RTT++
Expect(sf.StreamID).To(Equal(firstStr.StreamID()))
}
}
fmt.Fprintf(GinkgoWriter, "received %d STREAM frames in 0-RTT packets\n", num0RTT)
Expect(num0RTT).ToNot(BeZero())
})
It("transfers 0-RTT data, when 0-RTT packets are lost", func() {
var (
num0RTTPackets uint32 // to be used as an atomic
num0RTTDropped uint32
)
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Allow0RTT: true,
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration {
if wire.IsLongHeaderPacket(data[0]) {
hdr, _, _, err := wire.ParsePacket(data)
Expect(err).ToNot(HaveOccurred())
if hdr.Type == protocol.PacketType0RTT {
atomic.AddUint32(&num0RTTPackets, 1)
}
}
return rtt / 2
},
DropPacket: func(_ quicproxy.Direction, data []byte) bool {
if !wire.IsLongHeaderPacket(data[0]) {
return false
}
hdr, _, _, err := wire.ParsePacket(data)
Expect(err).ToNot(HaveOccurred())
if hdr.Type == protocol.PacketType0RTT {
// drop 25% of the 0-RTT packets
drop := mrand.Intn(4) == 0
if drop {
atomic.AddUint32(&num0RTTDropped, 1)
}
return drop
}
return false
},
})
Expect(err).ToNot(HaveOccurred())
defer proxy.Close()
transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData)
num0RTT := atomic.LoadUint32(&num0RTTPackets)
numDropped := atomic.LoadUint32(&num0RTTDropped)
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets. Dropped %d of those.", num0RTT, numDropped)
Expect(numDropped).ToNot(BeZero())
Expect(num0RTT).ToNot(BeZero())
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).ToNot(BeEmpty())
})
It("retransmits all 0-RTT data when the server performs a Retry", func() {
var mutex sync.Mutex
var firstConnID, secondConnID *protocol.ConnectionID
var firstCounter, secondCounter protocol.ByteCount
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
countZeroRTTBytes := func(data []byte) (n protocol.ByteCount) {
for len(data) > 0 {
hdr, _, rest, err := wire.ParsePacket(data)
if err != nil {
return
}
data = rest
if hdr.Type == protocol.PacketType0RTT {
n += hdr.Length - 16 /* AEAD tag */
}
}
return
}
tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
RequireAddressValidation: func(net.Addr) bool { return true },
Allow0RTT: true,
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
DelayPacket: func(dir quicproxy.Direction, data []byte) time.Duration {
connID, err := wire.ParseConnectionID(data, 0)
Expect(err).ToNot(HaveOccurred())
mutex.Lock()
defer mutex.Unlock()
if zeroRTTBytes := countZeroRTTBytes(data); zeroRTTBytes > 0 {
if firstConnID == nil {
firstConnID = &connID
firstCounter += zeroRTTBytes
} else if firstConnID != nil && *firstConnID == connID {
Expect(secondConnID).To(BeNil())
firstCounter += zeroRTTBytes
} else if secondConnID == nil {
secondConnID = &connID
secondCounter += zeroRTTBytes
} else if secondConnID != nil && *secondConnID == connID {
secondCounter += zeroRTTBytes
} else {
Fail("received 3 connection IDs on 0-RTT packets")
}
}
return rtt / 2
},
})
Expect(err).ToNot(HaveOccurred())
defer proxy.Close()
transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, GeneratePRData(5000)) // ~5 packets
mutex.Lock()
defer mutex.Unlock()
Expect(firstCounter).To(BeNumerically("~", 5000+100 /* framing overhead */, 100)) // the FIN bit might be sent extra
Expect(secondCounter).To(BeNumerically("~", firstCounter, 20))
zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets())
Expect(len(zeroRTTPackets)).To(BeNumerically(">=", 5))
Expect(zeroRTTPackets[0]).To(BeNumerically(">=", protocol.PacketNumber(5)))
})
It("doesn't reject 0-RTT when the server's transport stream limit increased", func() {
const maxStreams = 1
tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{
MaxIncomingUniStreams: maxStreams,
}))
tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
MaxIncomingUniStreams: maxStreams + 1,
Allow0RTT: true,
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
defer proxy.Close()
conn, err := quic.DialAddrEarly(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
clientConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
str, err := conn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
_, err = str.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed())
// The client remembers the old limit and refuses to open a new stream.
_, err = conn.OpenUniStream()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("too many open streams"))
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
_, err = conn.OpenUniStreamSync(ctx)
Expect(err).ToNot(HaveOccurred())
Expect(conn.ConnectionState().Used0RTT).To(BeTrue())
Expect(conn.CloseWithError(0, "")).To(Succeed())
})
It("rejects 0-RTT when the server's stream limit decreased", func() {
const maxStreams = 42
tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{
MaxIncomingStreams: maxStreams,
}))
tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
MaxIncomingStreams: maxStreams - 1,
Allow0RTT: true,
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
defer proxy.Close()
check0RTTRejected(ln, proxy.LocalPort(), clientConf)
// The client should send 0-RTT packets, but the server doesn't process them.
num0RTT := atomic.LoadUint32(num0RTTPackets)
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero())
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty())
})
It("rejects 0-RTT when the ALPN changed", func() {
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
// now close the listener and dial new connection with a different ALPN
clientConf.NextProtos = []string{"new-alpn"}
tlsConf.NextProtos = []string{"new-alpn"}
tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Allow0RTT: true,
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
defer proxy.Close()
check0RTTRejected(ln, proxy.LocalPort(), clientConf)
// The client should send 0-RTT packets, but the server doesn't process them.
num0RTT := atomic.LoadUint32(num0RTTPackets)
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero())
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty())
})
It("rejects 0-RTT when the application doesn't allow it", func() {
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
// now close the listener and dial new connection with a different ALPN
tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Allow0RTT: false, // application rejects 0-RTT
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
defer proxy.Close()
check0RTTRejected(ln, proxy.LocalPort(), clientConf)
// The client should send 0-RTT packets, but the server doesn't process them.
num0RTT := atomic.LoadUint32(num0RTTPackets)
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero())
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty())
})
DescribeTable("flow control limits",
func(addFlowControlLimit func(*quic.Config, uint64)) {
tracer := newPacketTracer()
firstConf := getQuicConfig(&quic.Config{Allow0RTT: true})
addFlowControlLimit(firstConf, 3)
tlsConf, clientConf := dialAndReceiveSessionTicket(firstConf)
secondConf := getQuicConfig(&quic.Config{
Allow0RTT: true,
Tracer: newTracer(tracer),
})
addFlowControlLimit(secondConf, 100)
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
secondConf,
)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
defer proxy.Close()
conn, err := quic.DialAddrEarly(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
clientConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
str, err := conn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
written := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(written)
_, err := str.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed())
}()
Eventually(written).Should(BeClosed())
serverConn, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
rstr, err := serverConn.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred())
data, err := io.ReadAll(rstr)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal([]byte("foobar")))
Expect(serverConn.ConnectionState().Used0RTT).To(BeTrue())
Expect(serverConn.CloseWithError(0, "")).To(Succeed())
Eventually(conn.Context().Done()).Should(BeClosed())
var processedFirst bool
for _, p := range tracer.getRcvdLongHeaderPackets() {
for _, f := range p.frames {
if sf, ok := f.(*logging.StreamFrame); ok {
if !processedFirst {
// The first STREAM should have been sent in a 0-RTT packet.
// Due to the flow control limit, the STREAM frame was limit to the first 3 bytes.
Expect(p.hdr.Type).To(Equal(protocol.PacketType0RTT))
Expect(sf.Length).To(BeEquivalentTo(3))
processedFirst = true
} else {
Fail("STREAM was shouldn't have been sent in 0-RTT")
}
}
}
}
},
Entry("doesn't reject 0-RTT when the server's transport stream flow control limit increased", func(c *quic.Config, limit uint64) { c.InitialStreamReceiveWindow = limit }),
Entry("doesn't reject 0-RTT when the server's transport connection flow control limit increased", func(c *quic.Config, limit uint64) { c.InitialConnectionReceiveWindow = limit }),
)
for _, l := range []int{0, 15} {
connIDLen := l
It(fmt.Sprintf("correctly deals with 0-RTT rejections, for %d byte connection IDs", connIDLen), func() {
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
// now dial new connection with different transport parameters
tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
MaxIncomingUniStreams: 1,
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
defer proxy.Close()
conn, err := quic.DialAddrEarly(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
clientConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
// The client remembers that it was allowed to open 2 uni-directional streams.
firstStr, err := conn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
written := make(chan struct{}, 2)
go func() {
defer GinkgoRecover()
defer func() { written <- struct{}{} }()
_, err := firstStr.Write([]byte("first flight"))
Expect(err).ToNot(HaveOccurred())
}()
secondStr, err := conn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()
defer func() { written <- struct{}{} }()
_, err := secondStr.Write([]byte("first flight"))
Expect(err).ToNot(HaveOccurred())
}()
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
_, err = conn.AcceptStream(ctx)
Expect(err).To(MatchError(quic.Err0RTTRejected))
Eventually(written).Should(Receive())
Eventually(written).Should(Receive())
_, err = firstStr.Write([]byte("foobar"))
Expect(err).To(MatchError(quic.Err0RTTRejected))
_, err = conn.OpenUniStream()
Expect(err).To(MatchError(quic.Err0RTTRejected))
_, err = conn.AcceptStream(ctx)
Expect(err).To(Equal(quic.Err0RTTRejected))
newConn := conn.NextConnection()
str, err := newConn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
_, err = newConn.OpenUniStream()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("too many open streams"))
_, err = str.Write([]byte("second flight"))
Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed())
Expect(conn.CloseWithError(0, "")).To(Succeed())
// The client should send 0-RTT packets, but the server doesn't process them.
num0RTT := atomic.LoadUint32(num0RTTPackets)
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero())
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty())
})
}
It("queues 0-RTT packets, if the Initial is delayed", func() {
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Allow0RTT: true,
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: ln.Addr().String(),
DelayPacket: func(dir quicproxy.Direction, data []byte) time.Duration {
if dir == quicproxy.DirectionIncoming && wire.IsLongHeaderPacket(data[0]) && data[0]&0x30>>4 == 0 { // Initial packet from client
return rtt/2 + rtt
}
return rtt / 2
},
})
Expect(err).ToNot(HaveOccurred())
defer proxy.Close()
transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData)
Expect(tracer.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial))
zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets())
Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10))
Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0)))
})
})

View file

@ -1,3 +1,5 @@
//go:build go1.21
package self_test
import (
@ -21,6 +23,36 @@ import (
. "github.com/onsi/gomega"
)
type metadataClientSessionCache struct {
toAdd []byte
restored func([]byte)
cache tls.ClientSessionCache
}
func (m metadataClientSessionCache) Get(key string) (*tls.ClientSessionState, bool) {
session, ok := m.cache.Get(key)
if !ok || session == nil {
return session, ok
}
ticket, state, err := session.ResumptionState()
Expect(err).ToNot(HaveOccurred())
Expect(state.Extra).To(HaveLen(2)) // ours, and the quic-go's
m.restored(state.Extra[1])
session, err = tls.NewResumptionState(ticket, state)
Expect(err).ToNot(HaveOccurred())
return session, true
}
func (m metadataClientSessionCache) Put(key string, session *tls.ClientSessionState) {
ticket, state, err := session.ResumptionState()
Expect(err).ToNot(HaveOccurred())
state.Extra = append(state.Extra, m.toAdd)
session, err = tls.NewResumptionState(ticket, state)
Expect(err).ToNot(HaveOccurred())
m.cache.Put(key, session)
}
var _ = Describe("0-RTT", func() {
rtt := scaleDuration(5 * time.Millisecond)
@ -49,15 +81,14 @@ var _ = Describe("0-RTT", func() {
return proxy, &num0RTTPackets
}
dialAndReceiveSessionTicket := func(serverConf *quic.Config) (*tls.Config, *tls.Config) {
tlsConf := getTLSConfig()
dialAndReceiveSessionTicket := func(serverTLSConf *tls.Config, serverConf *quic.Config, clientTLSConf *tls.Config) {
if serverConf == nil {
serverConf = getQuicConfig(nil)
}
serverConf.Allow0RTT = true
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
serverTLSConf,
serverConf,
)
Expect(err).ToNot(HaveOccurred())
@ -80,14 +111,16 @@ var _ = Describe("0-RTT", func() {
<-conn.Context().Done()
}()
clientConf := getTLSClientConfig()
gets := make(chan string, 100)
puts := make(chan string, 100)
clientConf.ClientSessionCache = newClientSessionCache(gets, puts)
cache := clientTLSConf.ClientSessionCache
if cache == nil {
cache = tls.NewLRUClientSessionCache(100)
}
clientTLSConf.ClientSessionCache = newClientSessionCache(cache, make(chan string, 100), puts)
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
clientConf,
clientTLSConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
@ -95,7 +128,6 @@ var _ = Describe("0-RTT", func() {
// received the session ticket. We're done here.
Expect(conn.CloseWithError(0, "")).To(Succeed())
Eventually(done).Should(BeClosed())
return tlsConf, clientConf
}
transfer0RTTData := func(
@ -118,7 +150,7 @@ var _ = Describe("0-RTT", func() {
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal(testdata))
Expect(str.Close()).To(Succeed())
Expect(conn.ConnectionState().TLS.Used0RTT).To(BeTrue())
Expect(conn.ConnectionState().Used0RTT).To(BeTrue())
<-conn.Context().Done()
close(done)
}()
@ -162,7 +194,7 @@ var _ = Describe("0-RTT", func() {
Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed())
<-conn.HandshakeComplete()
Expect(conn.ConnectionState().TLS.Used0RTT).To(BeTrue())
Expect(conn.ConnectionState().Used0RTT).To(BeTrue())
io.ReadAll(str) // wait for the EOF from the server to arrive before closing the conn
conn.CloseWithError(0, "")
Eventually(done).Should(BeClosed())
@ -186,14 +218,14 @@ var _ = Describe("0-RTT", func() {
_, err = str.Write(make([]byte, 3000))
Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed())
Expect(conn.ConnectionState().TLS.Used0RTT).To(BeFalse())
Expect(conn.ConnectionState().Used0RTT).To(BeFalse())
// make sure the server doesn't process the data
ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(50*time.Millisecond))
defer cancel()
serverConn, err := ln.Accept(ctx)
Expect(err).ToNot(HaveOccurred())
Expect(serverConn.ConnectionState().TLS.Used0RTT).To(BeFalse())
Expect(serverConn.ConnectionState().Used0RTT).To(BeFalse())
_, err = serverConn.AcceptUniStream(ctx)
Expect(err).To(Equal(context.DeadlineExceeded))
Expect(serverConn.CloseWithError(0, "")).To(Succeed())
@ -215,7 +247,9 @@ var _ = Describe("0-RTT", func() {
connIDLen := l
It(fmt.Sprintf("transfers 0-RTT data, with %d byte connection IDs", connIDLen), func() {
tlsConf, clientTLSConf := dialAndReceiveSessionTicket(nil)
tlsConf := getTLSConfig()
clientTLSConf := getTLSClientConfig()
dialAndReceiveSessionTicket(tlsConf, nil, clientTLSConf)
tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
@ -266,7 +300,9 @@ var _ = Describe("0-RTT", func() {
// Test that data intended to be sent with 1-RTT protection is not sent in 0-RTT packets.
It("waits for a connection until the handshake is done", func() {
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
tlsConf := getTLSConfig()
clientConf := getTLSClientConfig()
dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
zeroRTTData := GeneratePRData(5 << 10)
oneRTTData := PRData
@ -351,7 +387,9 @@ var _ = Describe("0-RTT", func() {
num0RTTDropped uint32
)
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
tlsConf := getTLSConfig()
clientConf := getTLSClientConfig()
dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
@ -412,7 +450,9 @@ var _ = Describe("0-RTT", func() {
var firstConnID, secondConnID *protocol.ConnectionID
var firstCounter, secondCounter protocol.ByteCount
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
tlsConf := getTLSConfig()
clientConf := getTLSClientConfig()
dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
countZeroRTTBytes := func(data []byte) (n protocol.ByteCount) {
for len(data) > 0 {
@ -485,9 +525,11 @@ var _ = Describe("0-RTT", func() {
It("doesn't reject 0-RTT when the server's transport stream limit increased", func() {
const maxStreams = 1
tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{
tlsConf := getTLSConfig()
clientConf := getTLSClientConfig()
dialAndReceiveSessionTicket(tlsConf, getQuicConfig(&quic.Config{
MaxIncomingUniStreams: maxStreams,
}))
}), clientConf)
tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
@ -524,15 +566,17 @@ var _ = Describe("0-RTT", func() {
defer cancel()
_, err = conn.OpenUniStreamSync(ctx)
Expect(err).ToNot(HaveOccurred())
Expect(conn.ConnectionState().TLS.Used0RTT).To(BeTrue())
Expect(conn.ConnectionState().Used0RTT).To(BeTrue())
Expect(conn.CloseWithError(0, "")).To(Succeed())
})
It("rejects 0-RTT when the server's stream limit decreased", func() {
const maxStreams = 42
tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{
tlsConf := getTLSConfig()
clientConf := getTLSClientConfig()
dialAndReceiveSessionTicket(tlsConf, getQuicConfig(&quic.Config{
MaxIncomingStreams: maxStreams,
}))
}), clientConf)
tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
@ -548,6 +592,7 @@ var _ = Describe("0-RTT", func() {
defer ln.Close()
proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
defer proxy.Close()
check0RTTRejected(ln, proxy.LocalPort(), clientConf)
// The client should send 0-RTT packets, but the server doesn't process them.
@ -558,11 +603,15 @@ var _ = Describe("0-RTT", func() {
})
It("rejects 0-RTT when the ALPN changed", func() {
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
tlsConf := getTLSConfig()
clientConf := getTLSClientConfig()
dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
// now close the listener and dial new connection with a different ALPN
clientConf.NextProtos = []string{"new-alpn"}
// switch to different ALPN on the server side
tlsConf.NextProtos = []string{"new-alpn"}
// Append to the client's ALPN.
// crypto/tls will attempt to resume with the ALPN from the original connection
clientConf.NextProtos = append(clientConf.NextProtos, "new-alpn")
tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
@ -587,7 +636,9 @@ var _ = Describe("0-RTT", func() {
})
It("rejects 0-RTT when the application doesn't allow it", func() {
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
tlsConf := getTLSConfig()
clientConf := getTLSClientConfig()
dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
// now close the listener and dial new connection with a different ALPN
tracer := newPacketTracer()
@ -618,7 +669,9 @@ var _ = Describe("0-RTT", func() {
tracer := newPacketTracer()
firstConf := getQuicConfig(&quic.Config{Allow0RTT: true})
addFlowControlLimit(firstConf, 3)
tlsConf, clientConf := dialAndReceiveSessionTicket(firstConf)
tlsConf := getTLSConfig()
clientConf := getTLSClientConfig()
dialAndReceiveSessionTicket(tlsConf, firstConf, clientConf)
secondConf := getQuicConfig(&quic.Config{
Allow0RTT: true,
@ -662,7 +715,7 @@ var _ = Describe("0-RTT", func() {
data, err := io.ReadAll(rstr)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal([]byte("foobar")))
Expect(serverConn.ConnectionState().TLS.Used0RTT).To(BeTrue())
Expect(serverConn.ConnectionState().Used0RTT).To(BeTrue())
Expect(serverConn.CloseWithError(0, "")).To(Succeed())
Eventually(conn.Context().Done()).Should(BeClosed())
@ -691,7 +744,9 @@ var _ = Describe("0-RTT", func() {
connIDLen := l
It(fmt.Sprintf("correctly deals with 0-RTT rejections, for %d byte connection IDs", connIDLen), func() {
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
tlsConf := getTLSConfig()
clientConf := getTLSClientConfig()
dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
// now dial new connection with different transport parameters
tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
@ -767,7 +822,9 @@ var _ = Describe("0-RTT", func() {
}
It("queues 0-RTT packets, if the Initial is delayed", func() {
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
tlsConf := getTLSConfig()
clientConf := getTLSClientConfig()
dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
@ -799,4 +856,86 @@ var _ = Describe("0-RTT", func() {
Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10))
Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0)))
})
It("allows the application to attach data to the session ticket, for the server", func() {
tlsConf := getTLSConfig()
tlsConf.WrapSession = func(cs tls.ConnectionState, ss *tls.SessionState) ([]byte, error) {
ss.Extra = append(ss.Extra, []byte("foobar"))
return tlsConf.EncryptTicket(cs, ss)
}
var unwrapped bool
tlsConf.UnwrapSession = func(identity []byte, cs tls.ConnectionState) (*tls.SessionState, error) {
defer GinkgoRecover()
state, err := tlsConf.DecryptTicket(identity, cs)
if err != nil {
return nil, err
}
Expect(state.Extra).To(HaveLen(2))
Expect(state.Extra[1]).To(Equal([]byte("foobar")))
unwrapped = true
return state, nil
}
clientTLSConf := getTLSClientConfig()
dialAndReceiveSessionTicket(tlsConf, nil, clientTLSConf)
tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Allow0RTT: true,
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
transfer0RTTData(
ln,
ln.Addr().(*net.UDPAddr).Port,
10,
clientTLSConf,
getQuicConfig(nil),
PRData,
)
Expect(unwrapped).To(BeTrue())
})
It("allows the application to attach data to the session ticket, for the client", func() {
tlsConf := getTLSConfig()
clientTLSConf := getTLSClientConfig()
var restored bool
clientTLSConf.ClientSessionCache = &metadataClientSessionCache{
toAdd: []byte("foobar"),
restored: func(b []byte) {
defer GinkgoRecover()
Expect(b).To(Equal([]byte("foobar")))
restored = true
},
cache: tls.NewLRUClientSessionCache(100),
}
dialAndReceiveSessionTicket(tlsConf, nil, clientTLSConf)
tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Allow0RTT: true,
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
transfer0RTTData(
ln,
ln.Addr().(*net.UDPAddr).Port,
10,
clientTLSConf,
getQuicConfig(nil),
PRData,
)
Expect(restored).To(BeTrue())
})
})

View file

@ -2,6 +2,7 @@ package quic
import (
"context"
"crypto/tls"
"errors"
"io"
"net"
@ -336,12 +337,14 @@ type ClientHelloInfo struct {
// ConnectionState records basic details about a QUIC connection
type ConnectionState struct {
// TLS contains information about the TLS connection state, incl. the tls.ConnectionState.
TLS handshake.ConnectionState
TLS tls.ConnectionState
// SupportsDatagrams says if support for QUIC datagrams (RFC 9221) was negotiated.
// This requires both nodes to support and enable the datagram extensions (via Config.EnableDatagrams).
// If datagram support was negotiated, datagrams can be sent and received using the
// SendMessage and ReceiveMessage methods on the Connection.
SupportsDatagrams bool
// Used0RTT says if 0-RTT resumption was used.
Used0RTT bool
// Version is the QUIC version of the QUIC connection.
Version VersionNumber
}

View file

@ -5,11 +5,10 @@ import (
"encoding/binary"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qtls"
"github.com/quic-go/quic-go/internal/utils"
)
func createAEAD(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, v protocol.VersionNumber) cipher.AEAD {
func createAEAD(suite *cipherSuite, trafficSecret []byte, v protocol.VersionNumber) cipher.AEAD {
keyLabel := hkdfLabelKeyV1
ivLabel := hkdfLabelIVV1
if v == protocol.Version2 {

View file

@ -0,0 +1,104 @@
package handshake
import (
"crypto"
"crypto/aes"
"crypto/cipher"
"crypto/tls"
"fmt"
"golang.org/x/crypto/chacha20poly1305"
)
// These cipher suite implementations are copied from the standard library crypto/tls package.
const aeadNonceLength = 12
type cipherSuite struct {
ID uint16
Hash crypto.Hash
KeyLen int
AEAD func(key, nonceMask []byte) cipher.AEAD
}
func (s cipherSuite) IVLen() int { return aeadNonceLength }
func getCipherSuite(id uint16) *cipherSuite {
switch id {
case tls.TLS_AES_128_GCM_SHA256:
return &cipherSuite{ID: tls.TLS_AES_128_GCM_SHA256, Hash: crypto.SHA256, KeyLen: 16, AEAD: aeadAESGCMTLS13}
case tls.TLS_CHACHA20_POLY1305_SHA256:
return &cipherSuite{ID: tls.TLS_CHACHA20_POLY1305_SHA256, Hash: crypto.SHA256, KeyLen: 32, AEAD: aeadChaCha20Poly1305}
case tls.TLS_AES_256_GCM_SHA384:
return &cipherSuite{ID: tls.TLS_AES_256_GCM_SHA384, Hash: crypto.SHA256, KeyLen: 32, AEAD: aeadAESGCMTLS13}
default:
panic(fmt.Sprintf("unknown cypher suite: %d", id))
}
}
func aeadAESGCMTLS13(key, nonceMask []byte) cipher.AEAD {
if len(nonceMask) != aeadNonceLength {
panic("tls: internal error: wrong nonce length")
}
aes, err := aes.NewCipher(key)
if err != nil {
panic(err)
}
aead, err := cipher.NewGCM(aes)
if err != nil {
panic(err)
}
ret := &xorNonceAEAD{aead: aead}
copy(ret.nonceMask[:], nonceMask)
return ret
}
func aeadChaCha20Poly1305(key, nonceMask []byte) cipher.AEAD {
if len(nonceMask) != aeadNonceLength {
panic("tls: internal error: wrong nonce length")
}
aead, err := chacha20poly1305.New(key)
if err != nil {
panic(err)
}
ret := &xorNonceAEAD{aead: aead}
copy(ret.nonceMask[:], nonceMask)
return ret
}
// xorNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce
// before each call.
type xorNonceAEAD struct {
nonceMask [aeadNonceLength]byte
aead cipher.AEAD
}
func (f *xorNonceAEAD) NonceSize() int { return 8 } // 64-bit sequence number
func (f *xorNonceAEAD) Overhead() int { return f.aead.Overhead() }
func (f *xorNonceAEAD) explicitNonceLen() int { return 0 }
func (f *xorNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte {
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
result := f.aead.Seal(out, f.nonceMask[:], plaintext, additionalData)
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
return result
}
func (f *xorNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) {
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
result, err := f.aead.Open(out, f.nonceMask[:], ciphertext, additionalData)
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
return result, err
}

View file

@ -7,9 +7,8 @@ import (
"errors"
"fmt"
"io"
"math"
"net"
"sync"
"sync/atomic"
"time"
"github.com/quic-go/quic-go/internal/protocol"
@ -25,96 +24,20 @@ type quicVersionContextKey struct{}
var QUICVersionContextKey = &quicVersionContextKey{}
// TLS unexpected_message alert
const alertUnexpectedMessage uint8 = 10
type messageType uint8
// TLS handshake message types.
const (
typeClientHello messageType = 1
typeServerHello messageType = 2
typeNewSessionTicket messageType = 4
typeEncryptedExtensions messageType = 8
typeCertificate messageType = 11
typeCertificateRequest messageType = 13
typeCertificateVerify messageType = 15
typeFinished messageType = 20
)
func (m messageType) String() string {
switch m {
case typeClientHello:
return "ClientHello"
case typeServerHello:
return "ServerHello"
case typeNewSessionTicket:
return "NewSessionTicket"
case typeEncryptedExtensions:
return "EncryptedExtensions"
case typeCertificate:
return "Certificate"
case typeCertificateRequest:
return "CertificateRequest"
case typeCertificateVerify:
return "CertificateVerify"
case typeFinished:
return "Finished"
default:
return fmt.Sprintf("unknown message type: %d", m)
}
}
const clientSessionStateRevision = 3
type conn struct {
localAddr, remoteAddr net.Addr
}
var _ net.Conn = &conn{}
func newConn(local, remote net.Addr) net.Conn {
return &conn{
localAddr: local,
remoteAddr: remote,
}
}
func (c *conn) Read([]byte) (int, error) { return 0, nil }
func (c *conn) Write([]byte) (int, error) { return 0, nil }
func (c *conn) Close() error { return nil }
func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr }
func (c *conn) LocalAddr() net.Addr { return c.localAddr }
func (c *conn) SetReadDeadline(time.Time) error { return nil }
func (c *conn) SetWriteDeadline(time.Time) error { return nil }
func (c *conn) SetDeadline(time.Time) error { return nil }
type cryptoSetup struct {
tlsConf *tls.Config
extraConf *qtls.ExtraConfig
conn *qtls.Conn
conn *qtls.QUICConn
version protocol.VersionNumber
messageChan chan []byte
isReadingHandshakeMessage chan struct{}
readFirstHandshakeMessage bool
ourParams *wire.TransportParameters
peerParams *wire.TransportParameters
paramsChan <-chan []byte
runner handshakeRunner
alertChan chan uint8
// handshakeDone is closed as soon as the go routine running qtls.Handshake() returns
handshakeDone chan struct{}
// is closed when Close() is called
closeChan chan struct{}
zeroRTTParameters *wire.TransportParameters
clientHelloWritten bool
clientHelloWrittenChan chan struct{} // is closed as soon as the ClientHello is written
zeroRTTParametersChan chan<- *wire.TransportParameters
allow0RTT bool
@ -129,9 +52,6 @@ type cryptoSetup struct {
handshakeCompleteTime time.Time
readEncLevel protocol.EncryptionLevel
writeEncLevel protocol.EncryptionLevel
zeroRTTOpener LongHeaderOpener // only set for the server
zeroRTTSealer LongHeaderSealer // only set for the client
@ -143,23 +63,20 @@ type cryptoSetup struct {
handshakeOpener LongHeaderOpener
handshakeSealer LongHeaderSealer
used0RTT atomic.Bool
oneRTTStream io.Writer
aead *updatableAEAD
has1RTTSealer bool
has1RTTOpener bool
}
var (
_ qtls.RecordLayer = &cryptoSetup{}
_ CryptoSetup = &cryptoSetup{}
)
var _ CryptoSetup = &cryptoSetup{}
// NewCryptoSetupClient creates a new crypto setup for the client
func NewCryptoSetupClient(
initialStream io.Writer,
handshakeStream io.Writer,
initialStream, handshakeStream, oneRTTStream io.Writer,
connID protocol.ConnectionID,
localAddr net.Addr,
remoteAddr net.Addr,
tp *wire.TransportParameters,
runner handshakeRunner,
tlsConf *tls.Config,
@ -172,28 +89,33 @@ func NewCryptoSetupClient(
cs, clientHelloWritten := newCryptoSetup(
initialStream,
handshakeStream,
oneRTTStream,
connID,
tp,
runner,
tlsConf,
enable0RTT,
rttStats,
tracer,
logger,
protocol.PerspectiveClient,
version,
)
cs.conn = qtls.Client(newConn(localAddr, remoteAddr), cs.tlsConf, cs.extraConf)
tlsConf = tlsConf.Clone()
tlsConf.MinVersion = tls.VersionTLS13
quicConf := &qtls.QUICConfig{TLSConfig: tlsConf}
qtls.SetupConfigForClient(quicConf, cs.marshalDataForSessionState, cs.handleDataFromSessionState)
cs.tlsConf = tlsConf
cs.conn = qtls.QUICClient(quicConf)
cs.conn.SetTransportParameters(cs.ourParams.Marshal(protocol.PerspectiveClient))
return cs, clientHelloWritten
}
// NewCryptoSetupServer creates a new crypto setup for the server
func NewCryptoSetupServer(
initialStream io.Writer,
handshakeStream io.Writer,
initialStream, handshakeStream, oneRTTStream io.Writer,
connID protocol.ConnectionID,
localAddr net.Addr,
remoteAddr net.Addr,
tp *wire.TransportParameters,
runner handshakeRunner,
tlsConf *tls.Config,
@ -206,29 +128,32 @@ func NewCryptoSetupServer(
cs, _ := newCryptoSetup(
initialStream,
handshakeStream,
oneRTTStream,
connID,
tp,
runner,
tlsConf,
allow0RTT,
rttStats,
tracer,
logger,
protocol.PerspectiveServer,
version,
)
cs.conn = qtls.Server(newConn(localAddr, remoteAddr), cs.tlsConf, cs.extraConf)
cs.allow0RTT = allow0RTT
quicConf := &qtls.QUICConfig{TLSConfig: tlsConf}
qtls.SetupConfigForServer(quicConf, cs.allow0RTT, cs.getDataForSessionTicket, cs.accept0RTT)
cs.tlsConf = quicConf.TLSConfig
cs.conn = qtls.QUICServer(quicConf)
return cs
}
func newCryptoSetup(
initialStream io.Writer,
handshakeStream io.Writer,
initialStream, handshakeStream, oneRTTStream io.Writer,
connID protocol.ConnectionID,
tp *wire.TransportParameters,
runner handshakeRunner,
tlsConf *tls.Config,
enable0RTT bool,
rttStats *utils.RTTStats,
tracer logging.ConnectionTracer,
logger utils.Logger,
@ -240,51 +165,23 @@ func newCryptoSetup(
tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient)
tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer)
}
extHandler := newExtensionHandler(tp.Marshal(perspective), perspective)
zeroRTTParametersChan := make(chan *wire.TransportParameters, 1)
cs := &cryptoSetup{
tlsConf: tlsConf,
return &cryptoSetup{
initialStream: initialStream,
initialSealer: initialSealer,
initialOpener: initialOpener,
handshakeStream: handshakeStream,
oneRTTStream: oneRTTStream,
aead: newUpdatableAEAD(rttStats, tracer, logger, version),
readEncLevel: protocol.EncryptionInitial,
writeEncLevel: protocol.EncryptionInitial,
runner: runner,
allow0RTT: enable0RTT,
ourParams: tp,
paramsChan: extHandler.TransportParameters(),
rttStats: rttStats,
tracer: tracer,
logger: logger,
perspective: perspective,
handshakeDone: make(chan struct{}),
alertChan: make(chan uint8),
clientHelloWrittenChan: make(chan struct{}),
zeroRTTParametersChan: zeroRTTParametersChan,
messageChan: make(chan []byte, 1),
isReadingHandshakeMessage: make(chan struct{}),
closeChan: make(chan struct{}),
version: version,
}
var maxEarlyData uint32
if enable0RTT {
maxEarlyData = math.MaxUint32
}
cs.extraConf = &qtls.ExtraConfig{
GetExtensions: extHandler.GetExtensions,
ReceivedExtensions: extHandler.ReceivedExtensions,
AlternativeRecordLayer: cs,
EnforceNextProtoSelection: true,
MaxEarlyData: maxEarlyData,
Accept0RTT: cs.accept0RTT,
Rejected0RTT: cs.rejected0RTT,
Enable0RTT: enable0RTT,
GetAppDataForSessionState: cs.marshalDataForSessionState,
SetAppDataFromSessionState: cs.handleDataFromSessionState,
}
return cs, zeroRTTParametersChan
}, zeroRTTParametersChan
}
func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) {
@ -301,142 +198,100 @@ func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) error {
return h.aead.SetLargestAcked(pn)
}
func (h *cryptoSetup) RunHandshake() {
// Handle errors that might occur when HandleData() is called.
handshakeComplete := make(chan struct{})
handshakeErrChan := make(chan error, 1)
go func() {
defer close(h.handshakeDone)
if err := h.conn.HandshakeContext(context.WithValue(context.Background(), QUICVersionContextKey, h.version)); err != nil {
handshakeErrChan <- err
return
func (h *cryptoSetup) StartHandshake() error {
err := h.conn.Start(context.WithValue(context.Background(), QUICVersionContextKey, h.version))
if err != nil {
return wrapError(err)
}
for {
ev := h.conn.NextEvent()
done, err := h.handleEvent(ev)
if err != nil {
return wrapError(err)
}
if done {
break
}
}
close(handshakeComplete)
}()
if h.perspective == protocol.PerspectiveClient {
select {
case err := <-handshakeErrChan:
h.onError(0, err.Error())
return
case <-h.clientHelloWrittenChan:
}
}
select {
case <-handshakeComplete: // return when the handshake is done
h.mutex.Lock()
h.handshakeCompleteTime = time.Now()
h.mutex.Unlock()
h.runner.OnHandshakeComplete()
case <-h.closeChan:
// wait until the Handshake() go routine has returned
<-h.handshakeDone
case alert := <-h.alertChan:
handshakeErr := <-handshakeErrChan
h.onError(alert, handshakeErr.Error())
}
}
func (h *cryptoSetup) onError(alert uint8, message string) {
var err error
if alert == 0 {
err = &qerr.TransportError{ErrorCode: qerr.InternalError, ErrorMessage: message}
if h.zeroRTTSealer != nil && h.zeroRTTParameters != nil {
h.logger.Debugf("Doing 0-RTT.")
h.zeroRTTParametersChan <- h.zeroRTTParameters
} else {
err = qerr.NewLocalCryptoError(alert, message)
h.logger.Debugf("Not doing 0-RTT. Has sealer: %t, has params: %t", h.zeroRTTSealer != nil, h.zeroRTTParameters != nil)
h.zeroRTTParametersChan <- nil
}
h.runner.OnError(err)
}
return nil
}
// Close closes the crypto setup.
// It aborts the handshake, if it is still running.
// It must only be called once.
func (h *cryptoSetup) Close() error {
close(h.closeChan)
// wait until qtls.Handshake() actually returned
<-h.handshakeDone
return nil
return h.conn.Close()
}
// handleMessage handles a TLS handshake message.
// HandleMessage handles a TLS handshake message.
// It is called by the crypto streams when a new message is available.
// It returns if it is done with messages on the same encryption level.
func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) bool /* stream finished */ {
msgType := messageType(data[0])
h.logger.Debugf("Received %s message (%d bytes, encryption level: %s)", msgType, len(data), encLevel)
if err := h.checkEncryptionLevel(msgType, encLevel); err != nil {
h.onError(alertUnexpectedMessage, err.Error())
return false
}
if encLevel != protocol.Encryption1RTT {
select {
case h.messageChan <- data:
case <-h.handshakeDone: // handshake errored, nobody is going to consume this message
return false
}
}
if encLevel == protocol.Encryption1RTT {
h.messageChan <- data
h.handlePostHandshakeMessage()
return false
}
readLoop:
for {
select {
case data := <-h.paramsChan:
if data == nil {
h.onError(0x6d, "missing quic_transport_parameters extension")
} else {
h.handleTransportParameters(data)
}
case <-h.isReadingHandshakeMessage:
break readLoop
case <-h.handshakeDone:
break readLoop
case <-h.closeChan:
break readLoop
}
}
// We're done with the Initial encryption level after processing a ClientHello / ServerHello,
// but only if a handshake opener and sealer was created.
// Otherwise, a HelloRetryRequest was performed.
// We're done with the Handshake encryption level after processing the Finished message.
return ((msgType == typeClientHello || msgType == typeServerHello) && h.handshakeOpener != nil && h.handshakeSealer != nil) ||
msgType == typeFinished
}
func (h *cryptoSetup) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error {
var expected protocol.EncryptionLevel
switch msgType {
case typeClientHello, typeServerHello:
expected = protocol.EncryptionInitial
case typeEncryptedExtensions,
typeCertificate,
typeCertificateRequest,
typeCertificateVerify,
typeFinished:
expected = protocol.EncryptionHandshake
case typeNewSessionTicket:
expected = protocol.Encryption1RTT
default:
return fmt.Errorf("unexpected handshake message: %d", msgType)
}
if encLevel != expected {
return fmt.Errorf("expected handshake message %s to have encryption level %s, has %s", msgType, expected, encLevel)
func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) error {
if err := h.handleMessage(data, encLevel); err != nil {
return wrapError(err)
}
return nil
}
func (h *cryptoSetup) handleTransportParameters(data []byte) {
func (h *cryptoSetup) handleMessage(data []byte, encLevel protocol.EncryptionLevel) error {
if err := h.conn.HandleData(qtls.ToTLSEncryptionLevel(encLevel), data); err != nil {
return err
}
for {
ev := h.conn.NextEvent()
done, err := h.handleEvent(ev)
if err != nil {
return err
}
if done {
return nil
}
}
}
func (h *cryptoSetup) handleEvent(ev qtls.QUICEvent) (done bool, err error) {
switch ev.Kind {
case qtls.QUICNoEvent:
return true, nil
case qtls.QUICSetReadSecret:
h.SetReadKey(ev.Level, ev.Suite, ev.Data)
return false, nil
case qtls.QUICSetWriteSecret:
h.SetWriteKey(ev.Level, ev.Suite, ev.Data)
return false, nil
case qtls.QUICTransportParameters:
return false, h.handleTransportParameters(ev.Data)
case qtls.QUICTransportParametersRequired:
h.conn.SetTransportParameters(h.ourParams.Marshal(h.perspective))
return false, nil
case qtls.QUICRejectedEarlyData:
h.rejected0RTT()
return false, nil
case qtls.QUICWriteData:
return false, h.WriteRecord(ev.Level, ev.Data)
case qtls.QUICHandshakeDone:
h.handshakeComplete()
return false, nil
default:
return false, fmt.Errorf("unexpected event: %d", ev.Kind)
}
}
func (h *cryptoSetup) handleTransportParameters(data []byte) error {
var tp wire.TransportParameters
if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil {
h.runner.OnError(&qerr.TransportError{
ErrorCode: qerr.TransportParameterError,
ErrorMessage: err.Error(),
})
return err
}
h.peerParams = &tp
h.runner.OnReceivedParams(h.peerParams)
return nil
}
// must be called after receiving the transport parameters
@ -477,17 +332,32 @@ func (h *cryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.Transpo
return &tp, nil
}
// only valid for the server
func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
var appData []byte
// Save transport parameters to the session ticket if we're allowing 0-RTT.
if h.extraConf.MaxEarlyData > 0 {
appData = (&sessionTicket{
func (h *cryptoSetup) getDataForSessionTicket() []byte {
return (&sessionTicket{
Parameters: h.ourParams,
RTT: h.rttStats.SmoothedRTT(),
}).Marshal()
}
return h.conn.GetSessionTicket(appData)
// GetSessionTicket generates a new session ticket.
// Due to limitations in crypto/tls, it's only possible to generate a single session ticket per connection.
// It is only valid for the server.
func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
if h.tlsConf.SessionTicketsDisabled {
return nil, nil
}
if err := h.conn.SendSessionTicket(h.allow0RTT); err != nil {
return nil, err
}
ev := h.conn.NextEvent()
if ev.Kind != qtls.QUICWriteData || ev.Level != qtls.QUICEncryptionLevelApplication {
panic("crypto/tls bug: where's my session ticket?")
}
ticket := ev.Data
if ev := h.conn.NextEvent(); ev.Kind != qtls.QUICNoEvent {
panic("crypto/tls bug: why more than one ticket?")
}
return ticket, nil
}
// accept0RTT is called for the server when receiving the client's session ticket.
@ -526,60 +396,12 @@ func (h *cryptoSetup) rejected0RTT() {
}
}
func (h *cryptoSetup) handlePostHandshakeMessage() {
// make sure the handshake has already completed
<-h.handshakeDone
done := make(chan struct{})
defer close(done)
// h.alertChan is an unbuffered channel.
// If an error occurs during conn.HandlePostHandshakeMessage,
// it will be sent on this channel.
// Read it from a go-routine so that HandlePostHandshakeMessage doesn't deadlock.
alertChan := make(chan uint8, 1)
go func() {
<-h.isReadingHandshakeMessage
select {
case alert := <-h.alertChan:
alertChan <- alert
case <-done:
}
}()
if err := h.conn.HandlePostHandshakeMessage(); err != nil {
select {
case <-h.closeChan:
case alert := <-alertChan:
h.onError(alert, err.Error())
}
}
}
// ReadHandshakeMessage is called by TLS.
// It blocks until a new handshake message is available.
func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) {
if !h.readFirstHandshakeMessage {
h.readFirstHandshakeMessage = true
} else {
select {
case h.isReadingHandshakeMessage <- struct{}{}:
case <-h.closeChan:
return nil, errors.New("error while handling the handshake message")
}
}
select {
case msg := <-h.messageChan:
return msg, nil
case <-h.closeChan:
return nil, errors.New("error while handling the handshake message")
}
}
func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
func (h *cryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
suite := getCipherSuite(suiteID)
h.mutex.Lock()
switch encLevel {
case qtls.Encryption0RTT:
//nolint:exhaustive // The TLS stack doesn't export Initial keys.
switch el {
case qtls.QUICEncryptionLevelEarly:
if h.perspective == protocol.PerspectiveClient {
panic("Received 0-RTT read key for the client")
}
@ -587,16 +409,11 @@ func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.Ciph
createAEAD(suite, trafficSecret, h.version),
newHeaderProtector(suite, trafficSecret, true, h.version),
)
h.mutex.Unlock()
h.used0RTT.Store(true)
if h.logger.Debug() {
h.logger.Debugf("Installed 0-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID))
}
if h.tracer != nil {
h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective.Opposite())
}
return
case qtls.EncryptionHandshake:
h.readEncLevel = protocol.EncryptionHandshake
case qtls.QUICEncryptionLevelHandshake:
h.handshakeOpener = newHandshakeOpener(
createAEAD(suite, trafficSecret, h.version),
newHeaderProtector(suite, trafficSecret, true, h.version),
@ -606,8 +423,7 @@ func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.Ciph
if h.logger.Debug() {
h.logger.Debugf("Installed Handshake Read keys (using %s)", tls.CipherSuiteName(suite.ID))
}
case qtls.EncryptionApplication:
h.readEncLevel = protocol.Encryption1RTT
case qtls.QUICEncryptionLevelApplication:
h.aead.SetReadKey(suite, trafficSecret)
h.has1RTTOpener = true
if h.logger.Debug() {
@ -617,15 +433,18 @@ func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.Ciph
panic("unexpected read encryption level")
}
h.mutex.Unlock()
h.runner.OnReceivedReadKeys()
if h.tracer != nil {
h.tracer.UpdatedKeyFromTLS(h.readEncLevel, h.perspective.Opposite())
h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective.Opposite())
}
}
func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
suite := getCipherSuite(suiteID)
h.mutex.Lock()
switch encLevel {
case qtls.Encryption0RTT:
//nolint:exhaustive // The TLS stack doesn't export Initial keys.
switch el {
case qtls.QUICEncryptionLevelEarly:
if h.perspective == protocol.PerspectiveServer {
panic("Received 0-RTT write key for the server")
}
@ -640,9 +459,9 @@ func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.Cip
if h.tracer != nil {
h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective)
}
// don't set used0RTT here. 0-RTT might still get rejected.
return
case qtls.EncryptionHandshake:
h.writeEncLevel = protocol.EncryptionHandshake
case qtls.QUICEncryptionLevelHandshake:
h.handshakeSealer = newHandshakeSealer(
createAEAD(suite, trafficSecret, h.version),
newHeaderProtector(suite, trafficSecret, true, h.version),
@ -652,14 +471,15 @@ func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.Cip
if h.logger.Debug() {
h.logger.Debugf("Installed Handshake Write keys (using %s)", tls.CipherSuiteName(suite.ID))
}
case qtls.EncryptionApplication:
h.writeEncLevel = protocol.Encryption1RTT
case qtls.QUICEncryptionLevelApplication:
h.aead.SetWriteKey(suite, trafficSecret)
h.has1RTTSealer = true
if h.logger.Debug() {
h.logger.Debugf("Installed 1-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID))
}
if h.zeroRTTSealer != nil {
// Once we receive handshake keys, we know that 0-RTT was not rejected.
h.used0RTT.Store(true)
h.zeroRTTSealer = nil
h.logger.Debugf("Dropping 0-RTT keys.")
if h.tracer != nil {
@ -671,45 +491,30 @@ func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.Cip
}
h.mutex.Unlock()
if h.tracer != nil {
h.tracer.UpdatedKeyFromTLS(h.writeEncLevel, h.perspective)
h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective)
}
}
// WriteRecord is called when TLS writes data
func (h *cryptoSetup) WriteRecord(p []byte) (int, error) {
func (h *cryptoSetup) WriteRecord(encLevel qtls.QUICEncryptionLevel, p []byte) error {
h.mutex.Lock()
defer h.mutex.Unlock()
//nolint:exhaustive // LS records can only be written for Initial and Handshake.
switch h.writeEncLevel {
case protocol.EncryptionInitial:
var str io.Writer
//nolint:exhaustive // handshake records can only be written for Initial and Handshake.
switch encLevel {
case qtls.QUICEncryptionLevelInitial:
// assume that the first WriteRecord call contains the ClientHello
n, err := h.initialStream.Write(p)
if !h.clientHelloWritten && h.perspective == protocol.PerspectiveClient {
h.clientHelloWritten = true
close(h.clientHelloWrittenChan)
if h.zeroRTTSealer != nil && h.zeroRTTParameters != nil {
h.logger.Debugf("Doing 0-RTT.")
h.zeroRTTParametersChan <- h.zeroRTTParameters
} else {
h.logger.Debugf("Not doing 0-RTT.")
h.zeroRTTParametersChan <- nil
}
}
return n, err
case protocol.EncryptionHandshake:
return h.handshakeStream.Write(p)
str = h.initialStream
case qtls.QUICEncryptionLevelHandshake:
str = h.handshakeStream
case qtls.QUICEncryptionLevelApplication:
str = h.oneRTTStream
default:
panic(fmt.Sprintf("unexpected write encryption level: %s", h.writeEncLevel))
}
}
func (h *cryptoSetup) SendAlert(alert uint8) {
select {
case h.alertChan <- alert:
case <-h.closeChan:
// no need to send an alert when we've already closed
panic(fmt.Sprintf("unexpected write encryption level: %s", encLevel))
}
_, err := str.Write(p)
return err
}
// used a callback in the handshakeSealer and handshakeOpener
@ -722,6 +527,11 @@ func (h *cryptoSetup) dropInitialKeys() {
h.logger.Debugf("Dropping Initial keys.")
}
func (h *cryptoSetup) handshakeComplete() {
h.handshakeCompleteTime = time.Now()
h.runner.OnHandshakeComplete()
}
func (h *cryptoSetup) SetHandshakeConfirmed() {
h.aead.SetHandshakeConfirmed()
// drop Handshake keys
@ -839,5 +649,15 @@ func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) {
}
func (h *cryptoSetup) ConnectionState() ConnectionState {
return qtls.GetConnectionState(h.conn)
return ConnectionState{
ConnectionState: h.conn.ConnectionState(),
Used0RTT: h.used0RTT.Load(),
}
}
func wrapError(err error) error {
if alertErr := qtls.AlertError(0); errors.As(err, &alertErr) && alertErr != 80 {
return qerr.NewLocalCryptoError(uint8(alertErr), err.Error())
}
return &qerr.TransportError{ErrorCode: qerr.InternalError, ErrorMessage: err.Error()}
}

View file

@ -1,7 +1,6 @@
package handshake
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
@ -23,12 +22,10 @@ import (
. "github.com/onsi/gomega"
)
var helloRetryRequestRandom = []byte{ // See RFC 8446, Section 4.1.3.
0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11,
0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91,
0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E,
0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C,
}
const (
typeClientHello = 1
typeNewSessionTicket = 4
)
type chunk struct {
data []byte
@ -80,54 +77,7 @@ var _ = Describe("Crypto Setup TLS", func() {
}
})
It("returns Handshake() when an error occurs in qtls", func() {
sErrChan := make(chan error, 1)
runner := NewMockHandshakeRunner(mockCtrl)
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
_, sInitialStream, sHandshakeStream := initStreams()
var token protocol.StatelessResetToken
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{StatelessResetToken: &token},
runner,
testdata.GetTLSConfig(),
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
protocol.Version1,
)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
server.RunHandshake()
Expect(sErrChan).To(Receive(MatchError(&qerr.TransportError{
ErrorCode: 0x100 + qerr.TransportErrorCode(alertUnexpectedMessage),
ErrorMessage: "local error: tls: unexpected message",
})))
close(done)
}()
fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...)
handledMessage := make(chan struct{})
go func() {
defer GinkgoRecover()
server.HandleMessage(fakeCH, protocol.EncryptionInitial)
close(handledMessage)
}()
Eventually(handledMessage).Should(BeClosed())
Eventually(done).Should(BeClosed())
})
It("handles qtls errors occurring before during ClientHello generation", func() {
sErrChan := make(chan error, 1)
runner := NewMockHandshakeRunner(mockCtrl)
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
_, sInitialStream, sHandshakeStream := initStreams()
tlsConf := testdata.GetTLSConfig()
tlsConf.InsecureSkipVerify = true
@ -135,11 +85,10 @@ var _ = Describe("Crypto Setup TLS", func() {
cl, _ := NewCryptoSetupClient(
sInitialStream,
sHandshakeStream,
nil,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{},
runner,
NewMockHandshakeRunner(mockCtrl),
tlsConf,
false,
&utils.RTTStats{},
@ -148,32 +97,21 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.Version1,
)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
cl.RunHandshake()
close(done)
}()
Eventually(done).Should(BeClosed())
Expect(sErrChan).To(Receive(MatchError(&qerr.TransportError{
Expect(cl.StartHandshake()).To(MatchError(&qerr.TransportError{
ErrorCode: qerr.InternalError,
ErrorMessage: "tls: invalid NextProtos value",
})))
}))
})
It("errors when a message is received at the wrong encryption level", func() {
sErrChan := make(chan error, 1)
_, sInitialStream, sHandshakeStream := initStreams()
runner := NewMockHandshakeRunner(mockCtrl)
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
var token protocol.StatelessResetToken
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
nil,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{StatelessResetToken: &token},
runner,
testdata.GetTLSConfig(),
@ -184,90 +122,13 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.Version1,
)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
server.RunHandshake()
close(done)
}()
Expect(server.StartHandshake()).To(Succeed())
fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...)
server.HandleMessage(fakeCH, protocol.EncryptionHandshake) // wrong encryption level
Expect(sErrChan).To(Receive(MatchError(&qerr.TransportError{
ErrorCode: 0x100 + qerr.TransportErrorCode(alertUnexpectedMessage),
ErrorMessage: "expected handshake message ClientHello to have encryption level Initial, has Handshake",
})))
// make the go routine return
Expect(server.Close()).To(Succeed())
Eventually(done).Should(BeClosed())
})
It("returns Handshake() when handling a message fails", func() {
sErrChan := make(chan error, 1)
_, sInitialStream, sHandshakeStream := initStreams()
runner := NewMockHandshakeRunner(mockCtrl)
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
var token protocol.StatelessResetToken
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{StatelessResetToken: &token},
runner,
serverConf,
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
protocol.Version1,
)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
server.RunHandshake()
var err error
Expect(sErrChan).To(Receive(&err))
Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{}))
Expect(err.(*qerr.TransportError).ErrorCode).To(BeEquivalentTo(0x100 + int(alertUnexpectedMessage)))
close(done)
}()
fakeCH := append([]byte{byte(typeServerHello), 0, 0, 6}, []byte("foobar")...)
server.HandleMessage(fakeCH, protocol.EncryptionInitial) // wrong encryption level
Eventually(done).Should(BeClosed())
})
It("returns Handshake() when it is closed", func() {
_, sInitialStream, sHandshakeStream := initStreams()
var token protocol.StatelessResetToken
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{StatelessResetToken: &token},
NewMockHandshakeRunner(mockCtrl),
serverConf,
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
protocol.Version1,
)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
server.RunHandshake()
close(done)
}()
Expect(server.Close()).To(Succeed())
Eventually(done).Should(BeClosed())
fakeCH := append([]byte{typeClientHello, 0, 0, 6}, []byte("foobar")...)
// wrong encryption level
err := server.HandleMessage(fakeCH, protocol.EncryptionHandshake)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level"))
})
Context("doing the handshake", func() {
@ -297,55 +158,32 @@ var _ = Describe("Crypto Setup TLS", func() {
return rttStats
}
handshake := func(client CryptoSetup, cChunkChan <-chan chunk,
server CryptoSetup, sChunkChan <-chan chunk,
) {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
handshake := func(client CryptoSetup, cChunkChan <-chan chunk, server CryptoSetup, sChunkChan <-chan chunk) {
Expect(client.StartHandshake()).To(Succeed())
Expect(server.StartHandshake()).To(Succeed())
for {
select {
case c := <-cChunkChan:
msgType := messageType(c.data[0])
finished := server.HandleMessage(c.data, c.encLevel)
if msgType == typeFinished {
Expect(finished).To(BeTrue())
} else if msgType == typeClientHello {
// If this ClientHello didn't elicit a HelloRetryRequest, we're done with Initial keys.
_, err := server.GetHandshakeOpener()
Expect(finished).To(Equal(err == nil))
} else {
Expect(finished).To(BeFalse())
Expect(server.HandleMessage(c.data, c.encLevel)).To(Succeed())
continue
default:
}
select {
case c := <-sChunkChan:
msgType := messageType(c.data[0])
finished := client.HandleMessage(c.data, c.encLevel)
if msgType == typeFinished {
Expect(finished).To(BeTrue())
} else if msgType == typeServerHello {
Expect(finished).To(Equal(!bytes.Equal(c.data[6:6+32], helloRetryRequestRandom)))
} else {
Expect(finished).To(BeFalse())
Expect(client.HandleMessage(c.data, c.encLevel)).To(Succeed())
continue
default:
}
case <-done: // handshake complete
return
// no more messages to send from client and server. Handshake complete?
break
}
}
}()
go func() {
defer GinkgoRecover()
defer close(done)
server.RunHandshake()
ticket, err := server.GetSessionTicket()
Expect(err).ToNot(HaveOccurred())
if ticket != nil {
client.HandleMessage(ticket, protocol.Encryption1RTT)
Expect(client.HandleMessage(ticket, protocol.Encryption1RTT)).To(Succeed())
}
}()
client.RunHandshake()
Eventually(done).Should(BeClosed())
}
handshakeWithTLSConf := func(
@ -359,15 +197,14 @@ var _ = Describe("Crypto Setup TLS", func() {
cErrChan := make(chan error, 1)
cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { cErrChan <- e }).MaxTimes(1)
cRunner.EXPECT().OnReceivedReadKeys().MinTimes(2).MaxTimes(3) // 3 if using 0-RTT, 2 otherwise
cRunner.EXPECT().OnHandshakeComplete().Do(func() { cHandshakeComplete = true }).MaxTimes(1)
cRunner.EXPECT().DropKeys(gomock.Any()).MaxTimes(1)
client, clientHelloWrittenChan := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
nil,
protocol.ConnectionID{},
nil,
nil,
clientTransportParameters,
cRunner,
clientConf,
@ -383,7 +220,7 @@ var _ = Describe("Crypto Setup TLS", func() {
sErrChan := make(chan error, 1)
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }).MaxTimes(1)
sRunner.EXPECT().OnReceivedReadKeys().MinTimes(2).MaxTimes(3) // 3 if using 0-RTT, 2 otherwise
sRunner.EXPECT().OnHandshakeComplete().Do(func() { sHandshakeComplete = true }).MaxTimes(1)
if serverTransportParameters.StatelessResetToken == nil {
var token protocol.StatelessResetToken
@ -392,9 +229,8 @@ var _ = Describe("Crypto Setup TLS", func() {
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
nil,
protocol.ConnectionID{},
nil,
nil,
serverTransportParameters,
sRunner,
serverConf,
@ -462,9 +298,8 @@ var _ = Describe("Crypto Setup TLS", func() {
client, chChan := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
nil,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{},
runner,
&tls.Config{InsecureSkipVerify: true},
@ -475,24 +310,15 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.Version1,
)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
client.RunHandshake()
close(done)
}()
Expect(client.StartHandshake()).To(Succeed())
var ch chunk
Eventually(cChunkChan).Should(Receive(&ch))
Eventually(chChan).Should(Receive(BeNil()))
// make sure the whole ClientHello was written
Expect(len(ch.data)).To(BeNumerically(">=", 4))
Expect(messageType(ch.data[0])).To(Equal(typeClientHello))
Expect(ch.data[0]).To(BeEquivalentTo(typeClientHello))
length := int(ch.data[1])<<16 | int(ch.data[2])<<8 | int(ch.data[3])
Expect(len(ch.data) - 4).To(Equal(length))
// make the go routine return
Expect(client.Close()).To(Succeed())
Eventually(done).Should(BeClosed())
})
It("receives transport parameters", func() {
@ -500,14 +326,14 @@ var _ = Describe("Crypto Setup TLS", func() {
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cTransportParameters := &wire.TransportParameters{ActiveConnectionIDLimit: 2, MaxIdleTimeout: 0x42 * time.Second}
cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedReadKeys().Times(2)
cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *wire.TransportParameters) { sTransportParametersRcvd = tp })
cRunner.EXPECT().OnHandshakeComplete()
client, _ := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
nil,
protocol.ConnectionID{},
nil,
nil,
cTransportParameters,
cRunner,
clientConf,
@ -521,6 +347,7 @@ var _ = Describe("Crypto Setup TLS", func() {
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
var token protocol.StatelessResetToken
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedReadKeys().Times(2)
sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *wire.TransportParameters) { cTransportParametersRcvd = tp })
sRunner.EXPECT().OnHandshakeComplete()
sTransportParameters := &wire.TransportParameters{
@ -531,9 +358,8 @@ var _ = Describe("Crypto Setup TLS", func() {
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
nil,
protocol.ConnectionID{},
nil,
nil,
sTransportParameters,
sRunner,
serverConf,
@ -561,13 +387,13 @@ var _ = Describe("Crypto Setup TLS", func() {
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnReceivedReadKeys().Times(2)
cRunner.EXPECT().OnHandshakeComplete()
client, _ := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
nil,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{ActiveConnectionIDLimit: 2},
cRunner,
clientConf,
@ -581,14 +407,14 @@ var _ = Describe("Crypto Setup TLS", func() {
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnReceivedReadKeys().Times(2)
sRunner.EXPECT().OnHandshakeComplete()
var token protocol.StatelessResetToken
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
nil,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{ActiveConnectionIDLimit: 2, StatelessResetToken: &token},
sRunner,
serverConf,
@ -608,25 +434,23 @@ var _ = Describe("Crypto Setup TLS", func() {
Eventually(done).Should(BeClosed())
// inject an invalid session ticket
cRunner.EXPECT().OnError(&qerr.TransportError{
ErrorCode: 0x100 + qerr.TransportErrorCode(alertUnexpectedMessage),
ErrorMessage: "expected handshake message NewSessionTicket to have encryption level 1-RTT, has Handshake",
})
b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...)
client.HandleMessage(b, protocol.EncryptionHandshake)
err := client.HandleMessage(b, protocol.EncryptionHandshake)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level"))
})
It("errors when handling the NewSessionTicket fails", func() {
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnReceivedReadKeys().Times(2)
cRunner.EXPECT().OnHandshakeComplete()
client, _ := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
nil,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{ActiveConnectionIDLimit: 2},
cRunner,
clientConf,
@ -640,14 +464,14 @@ var _ = Describe("Crypto Setup TLS", func() {
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnReceivedReadKeys().Times(2)
sRunner.EXPECT().OnHandshakeComplete()
var token protocol.StatelessResetToken
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
nil,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{ActiveConnectionIDLimit: 2, StatelessResetToken: &token},
sRunner,
serverConf,
@ -667,13 +491,11 @@ var _ = Describe("Crypto Setup TLS", func() {
Eventually(done).Should(BeClosed())
// inject an invalid session ticket
cRunner.EXPECT().OnError(gomock.Any()).Do(func(err error) {
b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...)
err := client.HandleMessage(b, protocol.Encryption1RTT)
Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{}))
Expect(err.(*qerr.TransportError).ErrorCode.IsCryptoError()).To(BeTrue())
})
b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...)
client.HandleMessage(b, protocol.Encryption1RTT)
})
It("uses session resumption", func() {
csc := mocktls.NewMockClientSessionCache(mockCtrl)
@ -785,7 +607,6 @@ var _ = Describe("Crypto Setup TLS", func() {
Expect(clientHelloWrittenChan).To(Receive(BeNil()))
csc.EXPECT().Get(gomock.Any()).Return(state, true)
csc.EXPECT().Put(gomock.Any(), nil)
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
clientRTTStats := &utils.RTTStats{}
@ -840,7 +661,6 @@ var _ = Describe("Crypto Setup TLS", func() {
Expect(clientHelloWrittenChan).To(Receive(BeNil()))
csc.EXPECT().Get(gomock.Any()).Return(state, true)
csc.EXPECT().Put(gomock.Any(), nil)
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
clientRTTStats := &utils.RTTStats{}

View file

@ -6,8 +6,6 @@ import (
"strings"
"testing"
"github.com/quic-go/quic-go/internal/qtls"
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo/v2"
@ -41,8 +39,8 @@ func splitHexString(s string) (slice []byte) {
return
}
var cipherSuites = []*qtls.CipherSuiteTLS13{
qtls.CipherSuiteTLS13ByID(tls.TLS_AES_128_GCM_SHA256),
qtls.CipherSuiteTLS13ByID(tls.TLS_AES_256_GCM_SHA384),
qtls.CipherSuiteTLS13ByID(tls.TLS_CHACHA20_POLY1305_SHA256),
var cipherSuites = []*cipherSuite{
getCipherSuite(tls.TLS_AES_128_GCM_SHA256),
getCipherSuite(tls.TLS_AES_256_GCM_SHA384),
getCipherSuite(tls.TLS_CHACHA20_POLY1305_SHA256),
}

View file

@ -10,7 +10,6 @@ import (
"golang.org/x/crypto/chacha20"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qtls"
)
type headerProtector interface {
@ -25,7 +24,7 @@ func hkdfHeaderProtectionLabel(v protocol.VersionNumber) string {
return "quic hp"
}
func newHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool, v protocol.VersionNumber) headerProtector {
func newHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, v protocol.VersionNumber) headerProtector {
hkdfLabel := hkdfHeaderProtectionLabel(v)
switch suite.ID {
case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
@ -45,7 +44,7 @@ type aesHeaderProtector struct {
var _ headerProtector = &aesHeaderProtector{}
func newAESHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector {
func newAESHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector {
hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, hkdfLabel, suite.KeyLen)
block, err := aes.NewCipher(hpKey)
if err != nil {
@ -90,7 +89,7 @@ type chachaHeaderProtector struct {
var _ headerProtector = &chachaHeaderProtector{}
func newChaChaHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector {
func newChaChaHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector {
hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, hkdfLabel, suite.KeyLen)
p := &chachaHeaderProtector{

View file

@ -7,7 +7,6 @@ import (
"golang.org/x/crypto/hkdf"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qtls"
)
var (
@ -29,12 +28,7 @@ func getSalt(v protocol.VersionNumber) []byte {
return quicSaltV1
}
var initialSuite = &qtls.CipherSuiteTLS13{
ID: tls.TLS_AES_128_GCM_SHA256,
KeyLen: 16,
AEAD: qtls.AEADAESGCMTLS13,
Hash: crypto.SHA256,
}
var initialSuite = getCipherSuite(tls.TLS_AES_128_GCM_SHA256)
// NewInitialAEAD creates a new AEAD for Initial encryption / decryption.
func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v protocol.VersionNumber) (LongHeaderSealer, LongHeaderOpener) {
@ -50,8 +44,8 @@ func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v p
myKey, myIV := computeInitialKeyAndIV(mySecret, v)
otherKey, otherIV := computeInitialKeyAndIV(otherSecret, v)
encrypter := qtls.AEADAESGCMTLS13(myKey, myIV)
decrypter := qtls.AEADAESGCMTLS13(otherKey, otherIV)
encrypter := initialSuite.AEAD(myKey, myIV)
decrypter := initialSuite.AEAD(otherKey, otherIV)
return newLongHeaderSealer(encrypter, newHeaderProtector(initialSuite, mySecret, true, v)),
newLongHeaderOpener(decrypter, newAESHeaderProtector(initialSuite, otherSecret, true, hkdfHeaderProtectionLabel(v)))

View file

@ -1,12 +1,12 @@
package handshake
import (
"crypto/tls"
"errors"
"io"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qtls"
"github.com/quic-go/quic-go/internal/wire"
)
@ -22,9 +22,6 @@ var (
ErrDecryptionFailed = errors.New("decryption failed")
)
// ConnectionState contains information about the state of the connection.
type ConnectionState = qtls.ConnectionState
type headerDecryptor interface {
DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
}
@ -56,28 +53,26 @@ type ShortHeaderSealer interface {
KeyPhase() protocol.KeyPhaseBit
}
// A tlsExtensionHandler sends and received the QUIC TLS extension.
type tlsExtensionHandler interface {
GetExtensions(msgType uint8) []qtls.Extension
ReceivedExtensions(msgType uint8, exts []qtls.Extension)
TransportParameters() <-chan []byte
}
type handshakeRunner interface {
OnReceivedParams(*wire.TransportParameters)
OnHandshakeComplete()
OnError(error)
OnReceivedReadKeys()
DropKeys(protocol.EncryptionLevel)
}
type ConnectionState struct {
tls.ConnectionState
Used0RTT bool
}
// CryptoSetup handles the handshake and protecting / unprotecting packets
type CryptoSetup interface {
RunHandshake()
StartHandshake() error
io.Closer
ChangeConnectionID(protocol.ConnectionID)
GetSessionTicket() ([]byte, error)
HandleMessage([]byte, protocol.EncryptionLevel) bool
HandleMessage([]byte, protocol.EncryptionLevel) error
SetLargest1RTTAcked(protocol.PacketNumber) error
SetHandshakeConfirmed()
ConnectionState() ConnectionState

View file

@ -47,18 +47,6 @@ func (mr *MockHandshakeRunnerMockRecorder) DropKeys(arg0 interface{}) *gomock.Ca
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropKeys", reflect.TypeOf((*MockHandshakeRunner)(nil).DropKeys), arg0)
}
// OnError mocks base method.
func (m *MockHandshakeRunner) OnError(arg0 error) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "OnError", arg0)
}
// OnError indicates an expected call of OnError.
func (mr *MockHandshakeRunnerMockRecorder) OnError(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnError", reflect.TypeOf((*MockHandshakeRunner)(nil).OnError), arg0)
}
// OnHandshakeComplete mocks base method.
func (m *MockHandshakeRunner) OnHandshakeComplete() {
m.ctrl.T.Helper()
@ -82,3 +70,15 @@ func (mr *MockHandshakeRunnerMockRecorder) OnReceivedParams(arg0 interface{}) *g
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnReceivedParams", reflect.TypeOf((*MockHandshakeRunner)(nil).OnReceivedParams), arg0)
}
// OnReceivedReadKeys mocks base method.
func (m *MockHandshakeRunner) OnReceivedReadKeys() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "OnReceivedReadKeys")
}
// OnReceivedReadKeys indicates an expected call of OnReceivedReadKeys.
func (mr *MockHandshakeRunnerMockRecorder) OnReceivedReadKeys() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnReceivedReadKeys", reflect.TypeOf((*MockHandshakeRunner)(nil).OnReceivedReadKeys))
}

View file

@ -10,7 +10,7 @@ import (
"github.com/quic-go/quic-go/quicvarint"
)
const sessionTicketRevision = 2
const sessionTicketRevision = 3
type sessionTicket struct {
Parameters *wire.TransportParameters

View file

@ -1,61 +0,0 @@
package handshake
import (
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qtls"
)
const quicTLSExtensionType = 0x39
type extensionHandler struct {
ourParams []byte
paramsChan chan []byte
extensionType uint16
perspective protocol.Perspective
}
var _ tlsExtensionHandler = &extensionHandler{}
// newExtensionHandler creates a new extension handler
func newExtensionHandler(params []byte, pers protocol.Perspective) tlsExtensionHandler {
return &extensionHandler{
ourParams: params,
paramsChan: make(chan []byte),
perspective: pers,
extensionType: quicTLSExtensionType,
}
}
func (h *extensionHandler) GetExtensions(msgType uint8) []qtls.Extension {
if (h.perspective == protocol.PerspectiveClient && messageType(msgType) != typeClientHello) ||
(h.perspective == protocol.PerspectiveServer && messageType(msgType) != typeEncryptedExtensions) {
return nil
}
return []qtls.Extension{{
Type: h.extensionType,
Data: h.ourParams,
}}
}
func (h *extensionHandler) ReceivedExtensions(msgType uint8, exts []qtls.Extension) {
if (h.perspective == protocol.PerspectiveClient && messageType(msgType) != typeEncryptedExtensions) ||
(h.perspective == protocol.PerspectiveServer && messageType(msgType) != typeClientHello) {
return
}
var data []byte
for _, ext := range exts {
if ext.Type == h.extensionType {
data = ext.Data
break
}
}
h.paramsChan <- data
}
func (h *extensionHandler) TransportParameters() <-chan []byte {
return h.paramsChan
}

View file

@ -1,165 +0,0 @@
package handshake
import (
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qtls"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("TLS Extension Handler, for the server", func() {
var (
handlerServer tlsExtensionHandler
handlerClient tlsExtensionHandler
)
JustBeforeEach(func() {
handlerServer = newExtensionHandler([]byte("foobar"), protocol.PerspectiveServer)
handlerClient = newExtensionHandler([]byte("raboof"), protocol.PerspectiveClient)
})
Context("for the server", func() {
Context("sending", func() {
It("only adds TransportParameters for the Encrypted Extensions", func() {
// test 2 other handshake types
Expect(handlerServer.GetExtensions(uint8(typeCertificate))).To(BeEmpty())
Expect(handlerServer.GetExtensions(uint8(typeFinished))).To(BeEmpty())
})
It("adds TransportParameters to the EncryptedExtensions message", func() {
exts := handlerServer.GetExtensions(uint8(typeEncryptedExtensions))
Expect(exts).To(HaveLen(1))
Expect(exts[0].Type).To(BeEquivalentTo(quicTLSExtensionType))
Expect(exts[0].Data).To(Equal([]byte("foobar")))
})
})
Context("receiving", func() {
var chExts []qtls.Extension
JustBeforeEach(func() {
chExts = handlerClient.GetExtensions(uint8(typeClientHello))
Expect(chExts).To(HaveLen(1))
})
It("sends the extension on the channel", func() {
go func() {
defer GinkgoRecover()
handlerServer.ReceivedExtensions(uint8(typeClientHello), chExts)
}()
var data []byte
Eventually(handlerServer.TransportParameters()).Should(Receive(&data))
Expect(data).To(Equal([]byte("raboof")))
})
It("sends nil on the channel if the extension is missing", func() {
go func() {
defer GinkgoRecover()
handlerServer.ReceivedExtensions(uint8(typeClientHello), nil)
}()
var data []byte
Eventually(handlerServer.TransportParameters()).Should(Receive(&data))
Expect(data).To(BeEmpty())
})
It("ignores extensions with different code points", func() {
go func() {
defer GinkgoRecover()
exts := []qtls.Extension{{Type: 0x1337, Data: []byte("invalid")}}
handlerServer.ReceivedExtensions(uint8(typeClientHello), exts)
}()
var data []byte
Eventually(handlerServer.TransportParameters()).Should(Receive())
Expect(data).To(BeEmpty())
})
It("ignores extensions that are not sent with the ClientHello", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
handlerServer.ReceivedExtensions(uint8(typeFinished), chExts)
close(done)
}()
Consistently(handlerServer.TransportParameters()).ShouldNot(Receive())
Eventually(done).Should(BeClosed())
})
})
})
Context("for the client", func() {
Context("sending", func() {
It("only adds TransportParameters for the Encrypted Extensions", func() {
// test 2 other handshake types
Expect(handlerClient.GetExtensions(uint8(typeCertificate))).To(BeEmpty())
Expect(handlerClient.GetExtensions(uint8(typeFinished))).To(BeEmpty())
})
It("adds TransportParameters to the ClientHello message", func() {
exts := handlerClient.GetExtensions(uint8(typeClientHello))
Expect(exts).To(HaveLen(1))
Expect(exts[0].Type).To(BeEquivalentTo(quicTLSExtensionType))
Expect(exts[0].Data).To(Equal([]byte("raboof")))
})
})
Context("receiving", func() {
var chExts []qtls.Extension
JustBeforeEach(func() {
chExts = handlerServer.GetExtensions(uint8(typeEncryptedExtensions))
Expect(chExts).To(HaveLen(1))
})
It("sends the extension on the channel", func() {
go func() {
defer GinkgoRecover()
handlerClient.ReceivedExtensions(uint8(typeEncryptedExtensions), chExts)
}()
var data []byte
Eventually(handlerClient.TransportParameters()).Should(Receive(&data))
Expect(data).To(Equal([]byte("foobar")))
})
It("sends nil on the channel if the extension is missing", func() {
go func() {
defer GinkgoRecover()
handlerClient.ReceivedExtensions(uint8(typeEncryptedExtensions), nil)
}()
var data []byte
Eventually(handlerClient.TransportParameters()).Should(Receive(&data))
Expect(data).To(BeEmpty())
})
It("ignores extensions with different code points", func() {
go func() {
defer GinkgoRecover()
exts := []qtls.Extension{{Type: 0x1337, Data: []byte("invalid")}}
handlerClient.ReceivedExtensions(uint8(typeEncryptedExtensions), exts)
}()
var data []byte
Eventually(handlerClient.TransportParameters()).Should(Receive())
Expect(data).To(BeEmpty())
})
It("ignores extensions that are not sent with the EncryptedExtensions", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
handlerClient.ReceivedExtensions(uint8(typeFinished), chExts)
close(done)
}()
Consistently(handlerClient.TransportParameters()).ShouldNot(Receive())
Eventually(done).Should(BeClosed())
})
})
})
})

View file

@ -10,7 +10,6 @@ import (
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/qtls"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/logging"
)
@ -24,7 +23,7 @@ var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval
var FirstKeyUpdateInterval uint64 = 100
type updatableAEAD struct {
suite *qtls.CipherSuiteTLS13
suite *cipherSuite
keyPhase protocol.KeyPhase
largestAcked protocol.PacketNumber
@ -121,7 +120,7 @@ func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte
// SetReadKey sets the read key.
// For the client, this function is called before SetWriteKey.
// For the server, this function is called after SetWriteKey.
func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
func (a *updatableAEAD) SetReadKey(suite *cipherSuite, trafficSecret []byte) {
a.rcvAEAD = createAEAD(suite, trafficSecret, a.version)
a.headerDecrypter = newHeaderProtector(suite, trafficSecret, false, a.version)
if a.suite == nil {
@ -135,7 +134,7 @@ func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret [
// SetWriteKey sets the write key.
// For the client, this function is called after SetReadKey.
// For the server, this function is called before SetWriteKey.
func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
func (a *updatableAEAD) SetWriteKey(suite *cipherSuite, trafficSecret []byte) {
a.sendAEAD = createAEAD(suite, trafficSecret, a.version)
a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false, a.version)
if a.suite == nil {
@ -146,7 +145,7 @@ func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret
a.nextSendAEAD = createAEAD(suite, a.nextSendTrafficSecret, a.version)
}
func (a *updatableAEAD) setAEADParameters(aead cipher.AEAD, suite *qtls.CipherSuiteTLS13) {
func (a *updatableAEAD) setAEADParameters(aead cipher.AEAD, suite *cipherSuite) {
a.nonceBuf = make([]byte, aead.NonceSize())
a.aeadOverhead = aead.Overhead()
a.suite = suite

View file

@ -10,7 +10,6 @@ import (
gomock "github.com/golang/mock/gomock"
handshake "github.com/quic-go/quic-go/internal/handshake"
protocol "github.com/quic-go/quic-go/internal/protocol"
qtls "github.com/quic-go/quic-go/internal/qtls"
)
// MockCryptoSetup is a mock of CryptoSetup interface.
@ -63,10 +62,10 @@ func (mr *MockCryptoSetupMockRecorder) Close() *gomock.Call {
}
// ConnectionState mocks base method.
func (m *MockCryptoSetup) ConnectionState() qtls.ConnectionState {
func (m *MockCryptoSetup) ConnectionState() handshake.ConnectionState {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ConnectionState")
ret0, _ := ret[0].(qtls.ConnectionState)
ret0, _ := ret[0].(handshake.ConnectionState)
return ret0
}
@ -212,10 +211,10 @@ func (mr *MockCryptoSetupMockRecorder) GetSessionTicket() *gomock.Call {
}
// HandleMessage mocks base method.
func (m *MockCryptoSetup) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) bool {
func (m *MockCryptoSetup) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HandleMessage", arg0, arg1)
ret0, _ := ret[0].(bool)
ret0, _ := ret[0].(error)
return ret0
}
@ -225,18 +224,6 @@ func (mr *MockCryptoSetupMockRecorder) HandleMessage(arg0, arg1 interface{}) *go
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoSetup)(nil).HandleMessage), arg0, arg1)
}
// RunHandshake mocks base method.
func (m *MockCryptoSetup) RunHandshake() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "RunHandshake")
}
// RunHandshake indicates an expected call of RunHandshake.
func (mr *MockCryptoSetupMockRecorder) RunHandshake() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RunHandshake", reflect.TypeOf((*MockCryptoSetup)(nil).RunHandshake))
}
// SetHandshakeConfirmed mocks base method.
func (m *MockCryptoSetup) SetHandshakeConfirmed() {
m.ctrl.T.Helper()
@ -262,3 +249,17 @@ func (mr *MockCryptoSetupMockRecorder) SetLargest1RTTAcked(arg0 interface{}) *go
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLargest1RTTAcked", reflect.TypeOf((*MockCryptoSetup)(nil).SetLargest1RTTAcked), arg0)
}
// StartHandshake mocks base method.
func (m *MockCryptoSetup) StartHandshake() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "StartHandshake")
ret0, _ := ret[0].(error)
return ret0
}
// StartHandshake indicates an expected call of StartHandshake.
func (mr *MockCryptoSetupMockRecorder) StartHandshake() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartHandshake", reflect.TypeOf((*MockCryptoSetup)(nil).StartHandshake))
}

View file

@ -40,7 +40,7 @@ func (e TransportErrorCode) Message() string {
if !e.IsCryptoError() {
return ""
}
return qtls.Alert(e - 0x100).Error()
return qtls.AlertError(e - 0x100).Error()
}
func (e TransportErrorCode) String() string {

View file

@ -0,0 +1,66 @@
//go:build go1.21
package qtls
import (
"crypto"
"crypto/cipher"
"crypto/tls"
"fmt"
"unsafe"
)
type cipherSuiteTLS13 struct {
ID uint16
KeyLen int
AEAD func(key, fixedNonce []byte) cipher.AEAD
Hash crypto.Hash
}
//go:linkname cipherSuiteTLS13ByID crypto/tls.cipherSuiteTLS13ByID
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13
//go:linkname cipherSuitesTLS13 crypto/tls.cipherSuitesTLS13
var cipherSuitesTLS13 []unsafe.Pointer
//go:linkname defaultCipherSuitesTLS13 crypto/tls.defaultCipherSuitesTLS13
var defaultCipherSuitesTLS13 []uint16
//go:linkname defaultCipherSuitesTLS13NoAES crypto/tls.defaultCipherSuitesTLS13NoAES
var defaultCipherSuitesTLS13NoAES []uint16
var cipherSuitesModified bool
// SetCipherSuite modifies the cipherSuiteTLS13 slice of cipher suites inside qtls
// such that it only contains the cipher suite with the chosen id.
// The reset function returned resets them back to the original value.
func SetCipherSuite(id uint16) (reset func()) {
if cipherSuitesModified {
panic("cipher suites modified multiple times without resetting")
}
cipherSuitesModified = true
origCipherSuitesTLS13 := append([]unsafe.Pointer{}, cipherSuitesTLS13...)
origDefaultCipherSuitesTLS13 := append([]uint16{}, defaultCipherSuitesTLS13...)
origDefaultCipherSuitesTLS13NoAES := append([]uint16{}, defaultCipherSuitesTLS13NoAES...)
// The order is given by the order of the slice elements in cipherSuitesTLS13 in qtls.
switch id {
case tls.TLS_AES_128_GCM_SHA256:
cipherSuitesTLS13 = cipherSuitesTLS13[:1]
case tls.TLS_CHACHA20_POLY1305_SHA256:
cipherSuitesTLS13 = cipherSuitesTLS13[1:2]
case tls.TLS_AES_256_GCM_SHA384:
cipherSuitesTLS13 = cipherSuitesTLS13[2:]
default:
panic(fmt.Sprintf("unexpected cipher suite: %d", id))
}
defaultCipherSuitesTLS13 = []uint16{id}
defaultCipherSuitesTLS13NoAES = []uint16{id}
return func() {
cipherSuitesTLS13 = origCipherSuitesTLS13
defaultCipherSuitesTLS13 = origDefaultCipherSuitesTLS13
defaultCipherSuitesTLS13NoAES = origDefaultCipherSuitesTLS13NoAES
cipherSuitesModified = false
}
}

View file

@ -0,0 +1,52 @@
//go:build go1.21
package qtls
import (
"crypto/tls"
"fmt"
"net"
"github.com/quic-go/quic-go/internal/testdata"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Setting the Cipher Suite", func() {
for _, cs := range []uint16{tls.TLS_AES_128_GCM_SHA256, tls.TLS_CHACHA20_POLY1305_SHA256, tls.TLS_AES_256_GCM_SHA384} {
cs := cs
It(fmt.Sprintf("selects %s", tls.CipherSuiteName(cs)), func() {
reset := SetCipherSuite(cs)
defer reset()
ln, err := tls.Listen("tcp4", "localhost:0", testdata.GetTLSConfig())
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
done := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(done)
conn, err := ln.Accept()
Expect(err).ToNot(HaveOccurred())
_, err = conn.Read(make([]byte, 10))
Expect(err).ToNot(HaveOccurred())
Expect(conn.(*tls.Conn).ConnectionState().CipherSuite).To(Equal(cs))
}()
conn, err := tls.Dial(
"tcp4",
fmt.Sprintf("localhost:%d", ln.Addr().(*net.TCPAddr).Port),
&tls.Config{RootCAs: testdata.GetRootCA()},
)
Expect(err).ToNot(HaveOccurred())
_, err = conn.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
Expect(conn.ConnectionState().CipherSuite).To(Equal(cs))
Expect(conn.Close()).To(Succeed())
Eventually(done).Should(BeClosed())
})
}
})

View file

@ -0,0 +1,61 @@
//go:build go1.21
package qtls
import (
"crypto/tls"
)
type clientSessionCache struct {
getData func() []byte
setData func([]byte)
wrapped tls.ClientSessionCache
}
var _ tls.ClientSessionCache = &clientSessionCache{}
func (c clientSessionCache) Put(key string, cs *tls.ClientSessionState) {
if cs == nil {
c.wrapped.Put(key, nil)
return
}
ticket, state, err := cs.ResumptionState()
if err != nil || state == nil {
c.wrapped.Put(key, cs)
return
}
state.Extra = append(state.Extra, addExtraPrefix(c.getData()))
newCS, err := tls.NewResumptionState(ticket, state)
if err != nil {
// It's not clear why this would error. Just save the original state.
c.wrapped.Put(key, cs)
return
}
c.wrapped.Put(key, newCS)
}
func (c clientSessionCache) Get(key string) (*tls.ClientSessionState, bool) {
cs, ok := c.wrapped.Get(key)
if !ok || cs == nil {
return cs, ok
}
ticket, state, err := cs.ResumptionState()
if err != nil {
// It's not clear why this would error.
// Remove the ticket from the session cache, so we don't run into this error over and over again
c.wrapped.Put(key, nil)
return nil, false
}
// restore QUIC transport parameters and RTT stored in state.Extra
if extra := findExtraData(state.Extra); extra != nil {
c.setData(extra)
}
session, err := tls.NewResumptionState(ticket, state)
if err != nil {
// It's not clear why this would error.
// Remove the ticket from the session cache, so we don't run into this error over and over again
c.wrapped.Put(key, nil)
return nil, false
}
return session, true
}

View file

@ -0,0 +1,82 @@
//go:build go1.21
package qtls
import (
"crypto/tls"
"fmt"
"net"
"github.com/quic-go/quic-go/internal/testdata"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Client Session Cache", func() {
It("adds data to and restores data from a session ticket", func() {
ln, err := tls.Listen("tcp4", "localhost:0", testdata.GetTLSConfig())
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(done)
for {
conn, err := ln.Accept()
if err != nil {
return
}
_, err = conn.Read(make([]byte, 10))
Expect(err).ToNot(HaveOccurred())
_, err = conn.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
}
}()
restored := make(chan []byte, 1)
clientConf := &tls.Config{
RootCAs: testdata.GetRootCA(),
ClientSessionCache: &clientSessionCache{
wrapped: tls.NewLRUClientSessionCache(10),
getData: func() []byte { return []byte("session") },
setData: func(data []byte) { restored <- data },
},
}
conn, err := tls.Dial(
"tcp4",
fmt.Sprintf("localhost:%d", ln.Addr().(*net.TCPAddr).Port),
clientConf,
)
Expect(err).ToNot(HaveOccurred())
_, err = conn.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
Expect(conn.ConnectionState().DidResume).To(BeFalse())
Expect(restored).To(HaveLen(0))
_, err = conn.Read(make([]byte, 10))
Expect(err).ToNot(HaveOccurred())
Expect(conn.Close()).To(Succeed())
// make sure the cache can deal with nonsensical inputs
clientConf.ClientSessionCache.Put("foo", nil)
clientConf.ClientSessionCache.Put("bar", &tls.ClientSessionState{})
conn, err = tls.Dial(
"tcp4",
fmt.Sprintf("localhost:%d", ln.Addr().(*net.TCPAddr).Port),
clientConf,
)
Expect(err).ToNot(HaveOccurred())
_, err = conn.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
Expect(conn.ConnectionState().DidResume).To(BeTrue())
var restoredData []byte
Expect(restored).To(Receive(&restoredData))
Expect(restoredData).To(Equal([]byte("session")))
Expect(conn.Close()).To(Succeed())
Expect(ln.Close()).To(Succeed())
Eventually(done).Should(BeClosed())
})
})

View file

@ -1,145 +0,0 @@
//go:build go1.19 && !go1.20
package qtls
import (
"crypto"
"crypto/cipher"
"crypto/tls"
"fmt"
"net"
"unsafe"
"github.com/quic-go/qtls-go1-19"
)
type (
// Alert is a TLS alert
Alert = qtls.Alert
// A Certificate is qtls.Certificate.
Certificate = qtls.Certificate
// CertificateRequestInfo contains information about a certificate request.
CertificateRequestInfo = qtls.CertificateRequestInfo
// A CipherSuiteTLS13 is a cipher suite for TLS 1.3
CipherSuiteTLS13 = qtls.CipherSuiteTLS13
// ClientHelloInfo contains information about a ClientHello.
ClientHelloInfo = qtls.ClientHelloInfo
// ClientSessionCache is a cache used for session resumption.
ClientSessionCache = qtls.ClientSessionCache
// ClientSessionState is a state needed for session resumption.
ClientSessionState = qtls.ClientSessionState
// A Config is a qtls.Config.
Config = qtls.Config
// A Conn is a qtls.Conn.
Conn = qtls.Conn
// ConnectionState contains information about the state of the connection.
ConnectionState = qtls.ConnectionStateWith0RTT
// EncryptionLevel is the encryption level of a message.
EncryptionLevel = qtls.EncryptionLevel
// Extension is a TLS extension
Extension = qtls.Extension
// ExtraConfig is the qtls.ExtraConfig
ExtraConfig = qtls.ExtraConfig
// RecordLayer is a qtls RecordLayer.
RecordLayer = qtls.RecordLayer
)
const (
// EncryptionHandshake is the Handshake encryption level
EncryptionHandshake = qtls.EncryptionHandshake
// Encryption0RTT is the 0-RTT encryption level
Encryption0RTT = qtls.Encryption0RTT
// EncryptionApplication is the application data encryption level
EncryptionApplication = qtls.EncryptionApplication
)
// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3
func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD {
return qtls.AEADAESGCMTLS13(key, fixedNonce)
}
// Client returns a new TLS client side connection.
func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
return qtls.Client(conn, config, extraConfig)
}
// Server returns a new TLS server side connection.
func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
return qtls.Server(conn, config, extraConfig)
}
func GetConnectionState(conn *Conn) ConnectionState {
return conn.ConnectionStateWith0RTT()
}
// ToTLSConnectionState extracts the tls.ConnectionState
func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState {
return cs.ConnectionState
}
type cipherSuiteTLS13 struct {
ID uint16
KeyLen int
AEAD func(key, fixedNonce []byte) cipher.AEAD
Hash crypto.Hash
}
//go:linkname cipherSuiteTLS13ByID github.com/quic-go/qtls-go1-19.cipherSuiteTLS13ByID
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13
// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite.
func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 {
val := cipherSuiteTLS13ByID(id)
cs := (*cipherSuiteTLS13)(unsafe.Pointer(val))
return &qtls.CipherSuiteTLS13{
ID: cs.ID,
KeyLen: cs.KeyLen,
AEAD: cs.AEAD,
Hash: cs.Hash,
}
}
//go:linkname cipherSuitesTLS13 github.com/quic-go/qtls-go1-19.cipherSuitesTLS13
var cipherSuitesTLS13 []unsafe.Pointer
//go:linkname defaultCipherSuitesTLS13 github.com/quic-go/qtls-go1-19.defaultCipherSuitesTLS13
var defaultCipherSuitesTLS13 []uint16
//go:linkname defaultCipherSuitesTLS13NoAES github.com/quic-go/qtls-go1-19.defaultCipherSuitesTLS13NoAES
var defaultCipherSuitesTLS13NoAES []uint16
var cipherSuitesModified bool
// SetCipherSuite modifies the cipherSuiteTLS13 slice of cipher suites inside qtls
// such that it only contains the cipher suite with the chosen id.
// The reset function returned resets them back to the original value.
func SetCipherSuite(id uint16) (reset func()) {
if cipherSuitesModified {
panic("cipher suites modified multiple times without resetting")
}
cipherSuitesModified = true
origCipherSuitesTLS13 := append([]unsafe.Pointer{}, cipherSuitesTLS13...)
origDefaultCipherSuitesTLS13 := append([]uint16{}, defaultCipherSuitesTLS13...)
origDefaultCipherSuitesTLS13NoAES := append([]uint16{}, defaultCipherSuitesTLS13NoAES...)
// The order is given by the order of the slice elements in cipherSuitesTLS13 in qtls.
switch id {
case tls.TLS_AES_128_GCM_SHA256:
cipherSuitesTLS13 = cipherSuitesTLS13[:1]
case tls.TLS_CHACHA20_POLY1305_SHA256:
cipherSuitesTLS13 = cipherSuitesTLS13[1:2]
case tls.TLS_AES_256_GCM_SHA384:
cipherSuitesTLS13 = cipherSuitesTLS13[2:]
default:
panic(fmt.Sprintf("unexpected cipher suite: %d", id))
}
defaultCipherSuitesTLS13 = []uint16{id}
defaultCipherSuitesTLS13NoAES = []uint16{id}
return func() {
cipherSuitesTLS13 = origCipherSuitesTLS13
defaultCipherSuitesTLS13 = origDefaultCipherSuitesTLS13
defaultCipherSuitesTLS13NoAES = origDefaultCipherSuitesTLS13NoAES
cipherSuitesModified = false
}
}

View file

@ -1,101 +1,97 @@
//go:build go1.20
//go:build go1.20 && !go1.21
package qtls
import (
"crypto"
"crypto/cipher"
"crypto/tls"
"fmt"
"net"
"unsafe"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/qtls-go1-20"
)
type (
// Alert is a TLS alert
Alert = qtls.Alert
// A Certificate is qtls.Certificate.
Certificate = qtls.Certificate
// CertificateRequestInfo contains information about a certificate request.
CertificateRequestInfo = qtls.CertificateRequestInfo
// A CipherSuiteTLS13 is a cipher suite for TLS 1.3
CipherSuiteTLS13 = qtls.CipherSuiteTLS13
// ClientHelloInfo contains information about a ClientHello.
ClientHelloInfo = qtls.ClientHelloInfo
// ClientSessionCache is a cache used for session resumption.
ClientSessionCache = qtls.ClientSessionCache
// ClientSessionState is a state needed for session resumption.
ClientSessionState = qtls.ClientSessionState
// A Config is a qtls.Config.
Config = qtls.Config
// A Conn is a qtls.Conn.
Conn = qtls.Conn
// ConnectionState contains information about the state of the connection.
ConnectionState = qtls.ConnectionStateWith0RTT
// EncryptionLevel is the encryption level of a message.
EncryptionLevel = qtls.EncryptionLevel
// Extension is a TLS extension
Extension = qtls.Extension
// ExtraConfig is the qtls.ExtraConfig
ExtraConfig = qtls.ExtraConfig
// RecordLayer is a qtls RecordLayer.
RecordLayer = qtls.RecordLayer
QUICConn = qtls.QUICConn
QUICConfig = qtls.QUICConfig
QUICEvent = qtls.QUICEvent
QUICEventKind = qtls.QUICEventKind
QUICEncryptionLevel = qtls.QUICEncryptionLevel
AlertError = qtls.AlertError
)
const (
// EncryptionHandshake is the Handshake encryption level
EncryptionHandshake = qtls.EncryptionHandshake
// Encryption0RTT is the 0-RTT encryption level
Encryption0RTT = qtls.Encryption0RTT
// EncryptionApplication is the application data encryption level
EncryptionApplication = qtls.EncryptionApplication
QUICEncryptionLevelInitial = qtls.QUICEncryptionLevelInitial
QUICEncryptionLevelEarly = qtls.QUICEncryptionLevelEarly
QUICEncryptionLevelHandshake = qtls.QUICEncryptionLevelHandshake
QUICEncryptionLevelApplication = qtls.QUICEncryptionLevelApplication
)
// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3
func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD {
return qtls.AEADAESGCMTLS13(key, fixedNonce)
const (
QUICNoEvent = qtls.QUICNoEvent
QUICSetReadSecret = qtls.QUICSetReadSecret
QUICSetWriteSecret = qtls.QUICSetWriteSecret
QUICWriteData = qtls.QUICWriteData
QUICTransportParameters = qtls.QUICTransportParameters
QUICTransportParametersRequired = qtls.QUICTransportParametersRequired
QUICRejectedEarlyData = qtls.QUICRejectedEarlyData
QUICHandshakeDone = qtls.QUICHandshakeDone
)
func SetupConfigForServer(conf *QUICConfig, enable0RTT bool, getDataForSessionTicket func() []byte, accept0RTT func([]byte) bool) {
qtls.InitSessionTicketKeys(conf.TLSConfig)
conf.TLSConfig = conf.TLSConfig.Clone()
conf.TLSConfig.MinVersion = tls.VersionTLS13
conf.ExtraConfig = &qtls.ExtraConfig{
Enable0RTT: enable0RTT,
Accept0RTT: accept0RTT,
GetAppDataForSessionTicket: getDataForSessionTicket,
}
}
// Client returns a new TLS client side connection.
func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
return qtls.Client(conn, config, extraConfig)
func SetupConfigForClient(conf *QUICConfig, getDataForSessionState func() []byte, setDataFromSessionState func([]byte)) {
conf.ExtraConfig = &qtls.ExtraConfig{
GetAppDataForSessionState: getDataForSessionState,
SetAppDataFromSessionState: setDataFromSessionState,
}
}
// Server returns a new TLS server side connection.
func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
return qtls.Server(conn, config, extraConfig)
func QUICServer(config *QUICConfig) *QUICConn {
return qtls.QUICServer(config)
}
func GetConnectionState(conn *Conn) ConnectionState {
return conn.ConnectionStateWith0RTT()
func QUICClient(config *QUICConfig) *QUICConn {
return qtls.QUICClient(config)
}
// ToTLSConnectionState extracts the tls.ConnectionState
func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState {
return cs.ConnectionState
func ToTLSEncryptionLevel(e protocol.EncryptionLevel) qtls.QUICEncryptionLevel {
switch e {
case protocol.EncryptionInitial:
return qtls.QUICEncryptionLevelInitial
case protocol.EncryptionHandshake:
return qtls.QUICEncryptionLevelHandshake
case protocol.Encryption1RTT:
return qtls.QUICEncryptionLevelApplication
case protocol.Encryption0RTT:
return qtls.QUICEncryptionLevelEarly
default:
panic(fmt.Sprintf("unexpected encryption level: %s", e))
}
}
type cipherSuiteTLS13 struct {
ID uint16
KeyLen int
AEAD func(key, fixedNonce []byte) cipher.AEAD
Hash crypto.Hash
}
//go:linkname cipherSuiteTLS13ByID github.com/quic-go/qtls-go1-20.cipherSuiteTLS13ByID
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13
// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite.
func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 {
val := cipherSuiteTLS13ByID(id)
cs := (*cipherSuiteTLS13)(unsafe.Pointer(val))
return &qtls.CipherSuiteTLS13{
ID: cs.ID,
KeyLen: cs.KeyLen,
AEAD: cs.AEAD,
Hash: cs.Hash,
func FromTLSEncryptionLevel(e qtls.QUICEncryptionLevel) protocol.EncryptionLevel {
switch e {
case qtls.QUICEncryptionLevelInitial:
return protocol.EncryptionInitial
case qtls.QUICEncryptionLevelHandshake:
return protocol.EncryptionHandshake
case qtls.QUICEncryptionLevelApplication:
return protocol.Encryption1RTT
case qtls.QUICEncryptionLevelEarly:
return protocol.Encryption0RTT
default:
panic(fmt.Sprintf("unexpect encryption level: %s", e))
}
}

View file

@ -0,0 +1,28 @@
//go:build !go1.21
package qtls
import (
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/qtls-go1-20"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Go 1.20", func() {
It("converts to qtls.EncryptionLevel", func() {
Expect(ToTLSEncryptionLevel(protocol.EncryptionInitial)).To(Equal(qtls.QUICEncryptionLevelInitial))
Expect(ToTLSEncryptionLevel(protocol.EncryptionHandshake)).To(Equal(qtls.QUICEncryptionLevelHandshake))
Expect(ToTLSEncryptionLevel(protocol.Encryption1RTT)).To(Equal(qtls.QUICEncryptionLevelApplication))
Expect(ToTLSEncryptionLevel(protocol.Encryption0RTT)).To(Equal(qtls.QUICEncryptionLevelEarly))
})
It("converts from qtls.EncryptionLevel", func() {
Expect(FromTLSEncryptionLevel(qtls.QUICEncryptionLevelInitial)).To(Equal(protocol.EncryptionInitial))
Expect(FromTLSEncryptionLevel(qtls.QUICEncryptionLevelHandshake)).To(Equal(protocol.EncryptionHandshake))
Expect(FromTLSEncryptionLevel(qtls.QUICEncryptionLevelApplication)).To(Equal(protocol.Encryption1RTT))
Expect(FromTLSEncryptionLevel(qtls.QUICEncryptionLevelEarly)).To(Equal(protocol.Encryption0RTT))
})
})

View file

@ -2,4 +2,153 @@
package qtls
var _ int = "The version of quic-go you're using can't be built on Go 1.21 yet. For more details, please see https://github.com/quic-go/quic-go/wiki/quic-go-and-Go-versions."
import (
"bytes"
"crypto/tls"
"fmt"
"github.com/quic-go/quic-go/internal/protocol"
)
type (
QUICConn = tls.QUICConn
QUICConfig = tls.QUICConfig
QUICEvent = tls.QUICEvent
QUICEventKind = tls.QUICEventKind
QUICEncryptionLevel = tls.QUICEncryptionLevel
AlertError = tls.AlertError
)
const (
QUICEncryptionLevelInitial = tls.QUICEncryptionLevelInitial
QUICEncryptionLevelEarly = tls.QUICEncryptionLevelEarly
QUICEncryptionLevelHandshake = tls.QUICEncryptionLevelHandshake
QUICEncryptionLevelApplication = tls.QUICEncryptionLevelApplication
)
const (
QUICNoEvent = tls.QUICNoEvent
QUICSetReadSecret = tls.QUICSetReadSecret
QUICSetWriteSecret = tls.QUICSetWriteSecret
QUICWriteData = tls.QUICWriteData
QUICTransportParameters = tls.QUICTransportParameters
QUICTransportParametersRequired = tls.QUICTransportParametersRequired
QUICRejectedEarlyData = tls.QUICRejectedEarlyData
QUICHandshakeDone = tls.QUICHandshakeDone
)
func QUICServer(config *QUICConfig) *QUICConn { return tls.QUICServer(config) }
func QUICClient(config *QUICConfig) *QUICConn { return tls.QUICClient(config) }
func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, accept0RTT func([]byte) bool) {
conf := qconf.TLSConfig
// Workaround for https://github.com/golang/go/issues/60506.
// This initializes the session tickets _before_ cloning the config.
_, _ = conf.DecryptTicket(nil, tls.ConnectionState{})
conf = conf.Clone()
conf.MinVersion = tls.VersionTLS13
qconf.TLSConfig = conf
// add callbacks to save transport parameters into the session ticket
origWrapSession := conf.WrapSession
conf.WrapSession = func(cs tls.ConnectionState, state *tls.SessionState) ([]byte, error) {
// Add QUIC transport parameters if this is a 0-RTT packet.
// TODO(#3853): also save the RTT for non-0-RTT tickets
if state.EarlyData {
state.Extra = append(state.Extra, addExtraPrefix(getData()))
}
if origWrapSession != nil {
return origWrapSession(cs, state)
}
b, err := conf.EncryptTicket(cs, state)
return b, err
}
origUnwrapSession := conf.UnwrapSession
// UnwrapSession might be called multiple times, as the client can use multiple session tickets.
// However, using 0-RTT is only possible with the first session ticket.
// crypto/tls guarantees that this callback is called in the same order as the session ticket in the ClientHello.
var unwrapCount int
conf.UnwrapSession = func(identity []byte, connState tls.ConnectionState) (*tls.SessionState, error) {
unwrapCount++
var state *tls.SessionState
var err error
if origUnwrapSession != nil {
state, err = origUnwrapSession(identity, connState)
} else {
state, err = conf.DecryptTicket(identity, connState)
}
if err != nil || state == nil {
return nil, err
}
if state.EarlyData {
extra := findExtraData(state.Extra)
if unwrapCount == 1 && extra != nil { // first session ticket
state.EarlyData = accept0RTT(extra)
} else { // subsequent session ticket, can't be used for 0-RTT
state.EarlyData = false
}
}
return state, nil
}
}
func SetupConfigForClient(qconf *QUICConfig, getData func() []byte, setData func([]byte)) {
conf := qconf.TLSConfig
if conf.ClientSessionCache != nil {
origCache := conf.ClientSessionCache
conf.ClientSessionCache = &clientSessionCache{
wrapped: origCache,
getData: getData,
setData: setData,
}
}
}
func ToTLSEncryptionLevel(e protocol.EncryptionLevel) tls.QUICEncryptionLevel {
switch e {
case protocol.EncryptionInitial:
return tls.QUICEncryptionLevelInitial
case protocol.EncryptionHandshake:
return tls.QUICEncryptionLevelHandshake
case protocol.Encryption1RTT:
return tls.QUICEncryptionLevelApplication
case protocol.Encryption0RTT:
return tls.QUICEncryptionLevelEarly
default:
panic(fmt.Sprintf("unexpected encryption level: %s", e))
}
}
func FromTLSEncryptionLevel(e tls.QUICEncryptionLevel) protocol.EncryptionLevel {
switch e {
case tls.QUICEncryptionLevelInitial:
return protocol.EncryptionInitial
case tls.QUICEncryptionLevelHandshake:
return protocol.EncryptionHandshake
case tls.QUICEncryptionLevelApplication:
return protocol.Encryption1RTT
case tls.QUICEncryptionLevelEarly:
return protocol.Encryption0RTT
default:
panic(fmt.Sprintf("unexpect encryption level: %s", e))
}
}
const extraPrefix = "quic-go1"
func addExtraPrefix(b []byte) []byte {
return append([]byte(extraPrefix), b...)
}
func findExtraData(extras [][]byte) []byte {
prefix := []byte(extraPrefix)
for _, extra := range extras {
if len(extra) < len(prefix) || !bytes.Equal(prefix, extra[:len(prefix)]) {
continue
}
return extra[len(prefix):]
}
return nil
}

View file

@ -0,0 +1,55 @@
//go:build go1.21
package qtls
import (
"crypto/tls"
"github.com/quic-go/quic-go/internal/protocol"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Go 1.21", func() {
It("converts to tls.EncryptionLevel", func() {
Expect(ToTLSEncryptionLevel(protocol.EncryptionInitial)).To(Equal(tls.QUICEncryptionLevelInitial))
Expect(ToTLSEncryptionLevel(protocol.EncryptionHandshake)).To(Equal(tls.QUICEncryptionLevelHandshake))
Expect(ToTLSEncryptionLevel(protocol.Encryption1RTT)).To(Equal(tls.QUICEncryptionLevelApplication))
Expect(ToTLSEncryptionLevel(protocol.Encryption0RTT)).To(Equal(tls.QUICEncryptionLevelEarly))
})
It("converts from tls.EncryptionLevel", func() {
Expect(FromTLSEncryptionLevel(tls.QUICEncryptionLevelInitial)).To(Equal(protocol.EncryptionInitial))
Expect(FromTLSEncryptionLevel(tls.QUICEncryptionLevelHandshake)).To(Equal(protocol.EncryptionHandshake))
Expect(FromTLSEncryptionLevel(tls.QUICEncryptionLevelApplication)).To(Equal(protocol.Encryption1RTT))
Expect(FromTLSEncryptionLevel(tls.QUICEncryptionLevelEarly)).To(Equal(protocol.Encryption0RTT))
})
Context("setting up a tls.Config for the client", func() {
It("sets up a session cache if there's one present on the config", func() {
csc := tls.NewLRUClientSessionCache(1)
conf := &QUICConfig{TLSConfig: &tls.Config{ClientSessionCache: csc}}
SetupConfigForClient(conf, nil, nil)
Expect(conf.TLSConfig.ClientSessionCache).ToNot(BeNil())
Expect(conf.TLSConfig.ClientSessionCache).ToNot(Equal(csc))
})
It("doesn't set up a session cache if there's none present on the config", func() {
conf := &QUICConfig{TLSConfig: &tls.Config{}}
SetupConfigForClient(conf, nil, nil)
Expect(conf.TLSConfig.ClientSessionCache).To(BeNil())
})
})
Context("setting up a tls.Config for the server", func() {
It("sets the minimum TLS version to TLS 1.3", func() {
orig := &tls.Config{MinVersion: tls.VersionTLS12}
conf := &QUICConfig{TLSConfig: orig}
SetupConfigForServer(conf, false, nil, nil)
Expect(conf.TLSConfig.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
// check that the original config wasn't modified
Expect(orig.MinVersion).To(BeEquivalentTo(tls.VersionTLS12))
})
})
})

View file

@ -1,4 +1,4 @@
//go:build !go1.19
//go:build !go1.20
package qtls

View file

@ -1,17 +0,0 @@
package qtls
import (
"crypto/tls"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("qtls wrapper", func() {
It("gets cipher suites", func() {
for _, id := range []uint16{tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384, tls.TLS_CHACHA20_POLY1305_SHA256} {
cs := CipherSuiteTLS13ByID(id)
Expect(cs.ID).To(Equal(id))
}
})
})

View file

@ -31,6 +31,7 @@ func GetTLSConfig() *tls.Config {
panic(err)
}
return &tls.Config{
MinVersion: tls.VersionTLS13,
Certificates: []tls.Certificate{cert},
}
}

View file

@ -35,10 +35,10 @@ func (m *MockCryptoDataHandler) EXPECT() *MockCryptoDataHandlerMockRecorder {
}
// HandleMessage mocks base method.
func (m *MockCryptoDataHandler) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) bool {
func (m *MockCryptoDataHandler) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HandleMessage", arg0, arg1)
ret0, _ := ret[0].(bool)
ret0, _ := ret[0].(error)
return ret0
}

View file

@ -156,6 +156,8 @@ func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config
if t.isSingleUse {
onClose = func() { t.Close() }
}
tlsConf = tlsConf.Clone()
tlsConf.MinVersion = tls.VersionTLS13
return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false)
}
@ -172,6 +174,8 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C
if t.isSingleUse {
onClose = func() { t.Close() }
}
tlsConf = tlsConf.Clone()
tlsConf.MinVersion = tls.VersionTLS13
return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true)
}