mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 04:37:36 +03:00
use the new crypto/tls QUIC Transport (#3860)
This commit is contained in:
parent
4998733ae1
commit
3d89e545d3
55 changed files with 2197 additions and 1509 deletions
|
@ -1,11 +1,5 @@
|
|||
version: 2.1
|
||||
executors:
|
||||
test-go119:
|
||||
docker:
|
||||
- image: "cimg/go:1.19"
|
||||
environment:
|
||||
runrace: true
|
||||
TIMESCALE_FACTOR: 3
|
||||
test-go120:
|
||||
docker:
|
||||
- image: "cimg/go:1.20"
|
||||
|
@ -15,7 +9,7 @@ executors:
|
|||
|
||||
jobs:
|
||||
"test": &test
|
||||
executor: test-go119
|
||||
executor: test-go120
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
|
@ -39,14 +33,10 @@ jobs:
|
|||
- run:
|
||||
name: "Run version negotiation tests with qlog"
|
||||
command: go run github.com/onsi/ginkgo/v2/ginkgo -v -randomize-all -trace integrationtests/versionnegotiation -- -qlog
|
||||
go119:
|
||||
<<: *test
|
||||
go120:
|
||||
<<: *test
|
||||
executor: test-go120
|
||||
|
||||
workflows:
|
||||
workflow:
|
||||
jobs:
|
||||
- go119
|
||||
- go120
|
||||
|
|
2
.github/workflows/cross-compile.yml
vendored
2
.github/workflows/cross-compile.yml
vendored
|
@ -4,7 +4,7 @@ jobs:
|
|||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
go: [ "1.19.x", "1.20.x" ]
|
||||
go: [ "1.20.x", "1.21.0-rc.2" ]
|
||||
runs-on: ${{ fromJSON(vars['CROSS_COMPILE_RUNNER_UBUNTU'] || '"ubuntu-latest"') }}
|
||||
name: "Cross Compilation (Go ${{matrix.go}})"
|
||||
steps:
|
||||
|
|
2
.github/workflows/integration.yml
vendored
2
.github/workflows/integration.yml
vendored
|
@ -5,7 +5,7 @@ jobs:
|
|||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
go: [ "1.19.x", "1.20.x" ]
|
||||
go: [ "1.20.x", "1.21.0-rc.2" ]
|
||||
runs-on: ${{ fromJSON(vars['INTEGRATION_RUNNER_UBUNTU'] || '"ubuntu-latest"') }}
|
||||
env:
|
||||
DEBUG: false # set this to true to export qlogs and save them as artifacts
|
||||
|
|
2
.github/workflows/unit.yml
vendored
2
.github/workflows/unit.yml
vendored
|
@ -7,7 +7,7 @@ jobs:
|
|||
fail-fast: false
|
||||
matrix:
|
||||
os: [ "ubuntu", "windows", "macos" ]
|
||||
go: [ "1.19.x", "1.20.x" ]
|
||||
go: [ "1.20.x", "1.21.0-rc.2" ]
|
||||
runs-on: ${{ fromJSON(vars[format('UNIT_RUNNER_{0}', matrix.os)] || format('"{0}-latest"', matrix.os)) }}
|
||||
name: Unit tests (${{ matrix.os}}, Go ${{ matrix.go }})
|
||||
steps:
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
run:
|
||||
skip-files:
|
||||
- internal/handshake/cipher_suite.go
|
||||
linters-settings:
|
||||
depguard:
|
||||
type: blacklist
|
||||
|
|
|
@ -220,7 +220,8 @@ quic-go always aims to support the latest two Go releases.
|
|||
|
||||
### Dependency on forked crypto/tls
|
||||
|
||||
Since the standard library didn't provide any QUIC APIs before the Go 1.21 release, we had to fork crypto/tls to add the required APIs ourselves: [qtls for Go 1.20](https://github.com/quic-go/qtls-go1-20) and [qtls for Go 1.19](https://github.com/quic-go/qtls-go1-19). This had led to a lot of pain in the Go ecosystem, and we're happy that we can rely on Go 1.21 going forward.
|
||||
Since the standard library didn't provide any QUIC APIs before the Go 1.21 release, we had to fork crypto/tls to add the required APIs ourselves: [qtls for Go 1.20](https://github.com/quic-go/qtls-go1-20).
|
||||
This had led to a lot of pain in the Go ecosystem, and we're happy that we can rely on Go 1.21 going forward.
|
||||
|
||||
## Contributing
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ coverage:
|
|||
- http3/gzip_reader.go
|
||||
- interop/
|
||||
- internal/ackhandler/packet_linkedlist.go
|
||||
- internal/handshake/cipher_suite.go
|
||||
- internal/utils/byteinterval_linkedlist.go
|
||||
- internal/utils/newconnectionid_linkedlist.go
|
||||
- internal/utils/packetinterval_linkedlist.go
|
||||
|
|
|
@ -52,7 +52,7 @@ type streamManager interface {
|
|||
}
|
||||
|
||||
type cryptoStreamHandler interface {
|
||||
RunHandshake()
|
||||
StartHandshake() error
|
||||
ChangeConnectionID(protocol.ConnectionID)
|
||||
SetLargest1RTTAcked(protocol.PacketNumber) error
|
||||
SetHandshakeConfirmed()
|
||||
|
@ -98,15 +98,15 @@ type connRunner interface {
|
|||
|
||||
type handshakeRunner struct {
|
||||
onReceivedParams func(*wire.TransportParameters)
|
||||
onError func(error)
|
||||
onReceivedReadKeys func()
|
||||
dropKeys func(protocol.EncryptionLevel)
|
||||
onHandshakeComplete func()
|
||||
}
|
||||
|
||||
func (r *handshakeRunner) OnReceivedParams(tp *wire.TransportParameters) { r.onReceivedParams(tp) }
|
||||
func (r *handshakeRunner) OnError(e error) { r.onError(e) }
|
||||
func (r *handshakeRunner) DropKeys(el protocol.EncryptionLevel) { r.dropKeys(el) }
|
||||
func (r *handshakeRunner) OnHandshakeComplete() { r.onHandshakeComplete() }
|
||||
func (r *handshakeRunner) OnReceivedReadKeys() { r.onReceivedReadKeys() }
|
||||
|
||||
type closeError struct {
|
||||
err error
|
||||
|
@ -329,14 +329,13 @@ var newConnection = func(
|
|||
cs := handshake.NewCryptoSetupServer(
|
||||
initialStream,
|
||||
handshakeStream,
|
||||
s.oneRTTStream,
|
||||
clientDestConnID,
|
||||
conn.LocalAddr(),
|
||||
conn.RemoteAddr(),
|
||||
params,
|
||||
&handshakeRunner{
|
||||
onReceivedParams: s.handleTransportParameters,
|
||||
onError: s.closeLocal,
|
||||
dropKeys: s.dropEncryptionLevel,
|
||||
onReceivedReadKeys: s.receivedReadKeys,
|
||||
onHandshakeComplete: func() {
|
||||
runner.Retire(clientDestConnID)
|
||||
close(s.handshakeCompleteChan)
|
||||
|
@ -418,6 +417,7 @@ var newClientConnection = func(
|
|||
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize)
|
||||
initialStream := newCryptoStream()
|
||||
handshakeStream := newCryptoStream()
|
||||
oneRTTStream := newCryptoStream()
|
||||
params := &wire.TransportParameters{
|
||||
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
|
||||
InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
|
||||
|
@ -448,14 +448,13 @@ var newClientConnection = func(
|
|||
cs, clientHelloWritten := handshake.NewCryptoSetupClient(
|
||||
initialStream,
|
||||
handshakeStream,
|
||||
oneRTTStream,
|
||||
destConnID,
|
||||
conn.LocalAddr(),
|
||||
conn.RemoteAddr(),
|
||||
params,
|
||||
&handshakeRunner{
|
||||
onReceivedParams: s.handleTransportParameters,
|
||||
onError: s.closeLocal,
|
||||
dropKeys: s.dropEncryptionLevel,
|
||||
onReceivedReadKeys: s.receivedReadKeys,
|
||||
onHandshakeComplete: func() { close(s.handshakeCompleteChan) },
|
||||
},
|
||||
tlsConf,
|
||||
|
@ -467,7 +466,7 @@ var newClientConnection = func(
|
|||
)
|
||||
s.clientHelloWritten = clientHelloWritten
|
||||
s.cryptoStreamHandler = cs
|
||||
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, newCryptoStream())
|
||||
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, oneRTTStream)
|
||||
s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen)
|
||||
s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, initialStream, handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective)
|
||||
if len(tlsConf.ServerName) > 0 {
|
||||
|
@ -530,11 +529,9 @@ func (s *connection) run() error {
|
|||
|
||||
s.timer = *newTimer()
|
||||
|
||||
handshaking := make(chan struct{})
|
||||
go func() {
|
||||
defer close(handshaking)
|
||||
s.cryptoStreamHandler.RunHandshake()
|
||||
}()
|
||||
if err := s.cryptoStreamHandler.StartHandshake(); err != nil {
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
if err := s.sendQueue.Run(); err != nil {
|
||||
s.destroyImpl(err)
|
||||
|
@ -686,7 +683,6 @@ runLoop:
|
|||
}
|
||||
|
||||
s.cryptoStreamHandler.Close()
|
||||
<-handshaking
|
||||
s.sendQueue.Close() // close the send queue before sending the CONNECTION_CLOSE
|
||||
s.handleCloseError(&closeErr)
|
||||
if e := (&errCloseForRecreating{}); !errors.As(closeErr.err, &e) && s.tracer != nil {
|
||||
|
@ -717,7 +713,9 @@ func (s *connection) supportsDatagrams() bool {
|
|||
func (s *connection) ConnectionState() ConnectionState {
|
||||
s.connStateMutex.Lock()
|
||||
defer s.connStateMutex.Unlock()
|
||||
s.connState.TLS = s.cryptoStreamHandler.ConnectionState()
|
||||
cs := s.cryptoStreamHandler.ConnectionState()
|
||||
s.connState.TLS = cs.ConnectionState
|
||||
s.connState.Used0RTT = cs.Used0RTT
|
||||
return s.connState
|
||||
}
|
||||
|
||||
|
@ -786,7 +784,7 @@ func (s *connection) handleHandshakeComplete() {
|
|||
if err != nil {
|
||||
s.closeLocal(err)
|
||||
}
|
||||
if ticket != nil {
|
||||
if ticket != nil { // may be nil if session tickets are disabled via tls.Config.SessionTicketsDisabled
|
||||
s.oneRTTStream.Write(ticket)
|
||||
for s.oneRTTStream.HasData() {
|
||||
s.queueControlFrame(s.oneRTTStream.PopCryptoFrame(protocol.MaxPostHandshakeCryptoFrameSize))
|
||||
|
@ -1378,17 +1376,14 @@ func (s *connection) handleConnectionCloseFrame(frame *wire.ConnectionCloseFrame
|
|||
}
|
||||
|
||||
func (s *connection) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error {
|
||||
encLevelChanged, err := s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel)
|
||||
if err != nil {
|
||||
return err
|
||||
return s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel)
|
||||
}
|
||||
if encLevelChanged {
|
||||
|
||||
func (s *connection) receivedReadKeys() {
|
||||
// Queue all packets for decryption that have been undecryptable so far.
|
||||
s.undecryptablePacketsToProcess = s.undecryptablePackets
|
||||
s.undecryptablePackets = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *connection) handleStreamFrame(frame *wire.StreamFrame) error {
|
||||
str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID)
|
||||
|
@ -1629,11 +1624,15 @@ func (s *connection) handleCloseError(closeErr *closeError) {
|
|||
}
|
||||
|
||||
func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) {
|
||||
s.sentPacketHandler.DropPackets(encLevel)
|
||||
s.receivedPacketHandler.DropPackets(encLevel)
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedEncryptionLevel(encLevel)
|
||||
}
|
||||
s.sentPacketHandler.DropPackets(encLevel)
|
||||
s.receivedPacketHandler.DropPackets(encLevel)
|
||||
if err := s.cryptoStreamManager.Drop(encLevel); err != nil {
|
||||
s.closeLocal(err)
|
||||
return
|
||||
}
|
||||
if encLevel == protocol.Encryption0RTT {
|
||||
s.streamsMap.ResetFor0RTT()
|
||||
if err := s.connFlowController.Reset(); err != nil {
|
||||
|
@ -1817,6 +1816,9 @@ func (s *connection) sendPackets(now time.Time) error {
|
|||
s.framer.QueueControlFrame(&wire.DataBlockedFrame{MaximumData: offset})
|
||||
}
|
||||
s.windowUpdateQueue.QueueAll()
|
||||
if cf := s.cryptoStreamManager.GetPostHandshakeData(protocol.MaxPostHandshakeCryptoFrameSize); cf != nil {
|
||||
s.queueControlFrame(cf)
|
||||
}
|
||||
|
||||
if !s.handshakeConfirmed {
|
||||
packet, err := s.packer.PackCoalescedPacket(false, s.mtuDiscoverer.CurrentSize(), s.version)
|
||||
|
|
|
@ -119,7 +119,7 @@ var _ = Describe("Connection", func() {
|
|||
&protocol.DefaultConnectionIDGenerator{},
|
||||
protocol.StatelessResetToken{},
|
||||
populateServerConfig(&Config{DisablePathMTUDiscovery: true}),
|
||||
nil, // tls.Config
|
||||
&tls.Config{},
|
||||
tokenGenerator,
|
||||
false,
|
||||
tracer,
|
||||
|
@ -357,7 +357,7 @@ var _ = Describe("Connection", func() {
|
|||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
Expect(conn.run()).To(MatchError(expectedErr))
|
||||
}()
|
||||
Expect(conn.handleFrame(&wire.ConnectionCloseFrame{
|
||||
|
@ -385,7 +385,7 @@ var _ = Describe("Connection", func() {
|
|||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
Expect(conn.run()).To(MatchError(testErr))
|
||||
}()
|
||||
ccf := &wire.ConnectionCloseFrame{
|
||||
|
@ -432,7 +432,7 @@ var _ = Describe("Connection", func() {
|
|||
runConn := func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
runErr <- conn.run()
|
||||
}()
|
||||
Eventually(areConnsRunning).Should(BeTrue())
|
||||
|
@ -811,7 +811,7 @@ var _ = Describe("Connection", func() {
|
|||
packer.EXPECT().PackConnectionClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
conn.run()
|
||||
}()
|
||||
expectReplaceWithClosed()
|
||||
|
@ -853,7 +853,7 @@ var _ = Describe("Connection", func() {
|
|||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
conn.run()
|
||||
}()
|
||||
Consistently(conn.Context().Done()).ShouldNot(BeClosed())
|
||||
|
@ -888,7 +888,7 @@ var _ = Describe("Connection", func() {
|
|||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
conn.run()
|
||||
}()
|
||||
Consistently(conn.Context().Done()).ShouldNot(BeClosed())
|
||||
|
@ -913,7 +913,7 @@ var _ = Describe("Connection", func() {
|
|||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
err := conn.run()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{}))
|
||||
|
@ -937,7 +937,7 @@ var _ = Describe("Connection", func() {
|
|||
runErr := make(chan error)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
runErr <- conn.run()
|
||||
}()
|
||||
expectReplaceWithClosed()
|
||||
|
@ -961,7 +961,7 @@ var _ = Describe("Connection", func() {
|
|||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
err := conn.run()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{}))
|
||||
|
@ -1197,7 +1197,7 @@ var _ = Describe("Connection", func() {
|
|||
runConn := func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
conn.run()
|
||||
close(connDone)
|
||||
}()
|
||||
|
@ -1415,7 +1415,7 @@ var _ = Describe("Connection", func() {
|
|||
})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
conn.run()
|
||||
}()
|
||||
conn.scheduleSending()
|
||||
|
@ -1439,7 +1439,7 @@ var _ = Describe("Connection", func() {
|
|||
})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
conn.run()
|
||||
}()
|
||||
conn.scheduleSending()
|
||||
|
@ -1463,7 +1463,7 @@ var _ = Describe("Connection", func() {
|
|||
})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
conn.run()
|
||||
}()
|
||||
conn.scheduleSending()
|
||||
|
@ -1479,7 +1479,7 @@ var _ = Describe("Connection", func() {
|
|||
sender.EXPECT().Send(gomock.Any(), gomock.Any())
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
conn.run()
|
||||
}()
|
||||
conn.scheduleSending()
|
||||
|
@ -1496,7 +1496,7 @@ var _ = Describe("Connection", func() {
|
|||
sender.EXPECT().Send(gomock.Any(), gomock.Any())
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
conn.run()
|
||||
}()
|
||||
conn.scheduleSending()
|
||||
|
@ -1514,7 +1514,7 @@ var _ = Describe("Connection", func() {
|
|||
sender.EXPECT().Send(gomock.Any(), gomock.Any())
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
conn.run()
|
||||
}()
|
||||
conn.scheduleSending()
|
||||
|
@ -1540,7 +1540,7 @@ var _ = Describe("Connection", func() {
|
|||
sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { written <- struct{}{} }).Times(2)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
conn.run()
|
||||
}()
|
||||
conn.scheduleSending()
|
||||
|
@ -1562,7 +1562,7 @@ var _ = Describe("Connection", func() {
|
|||
sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { written <- struct{}{} }).Times(3)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
conn.run()
|
||||
}()
|
||||
conn.scheduleSending()
|
||||
|
@ -1580,7 +1580,7 @@ var _ = Describe("Connection", func() {
|
|||
sender.EXPECT().Available().Return(available)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
conn.run()
|
||||
}()
|
||||
conn.scheduleSending()
|
||||
|
@ -1602,7 +1602,7 @@ var _ = Describe("Connection", func() {
|
|||
sender.EXPECT().WouldBlock().AnyTimes()
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
conn.run()
|
||||
}()
|
||||
|
||||
|
@ -1633,7 +1633,7 @@ var _ = Describe("Connection", func() {
|
|||
sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { written <- struct{}{} })
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
conn.run()
|
||||
}()
|
||||
available := make(chan struct{}, 1)
|
||||
|
@ -1664,7 +1664,7 @@ var _ = Describe("Connection", func() {
|
|||
// don't EXPECT any calls to mconn.Write()
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
conn.run()
|
||||
}()
|
||||
conn.scheduleSending() // no packet will get sent
|
||||
|
@ -1687,7 +1687,7 @@ var _ = Describe("Connection", func() {
|
|||
packer.EXPECT().PackMTUProbePacket(ping, protocol.ByteCount(1234), conn.version).Return(shortHeaderPacket{PacketNumber: 1}, getPacketBuffer(), nil)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
conn.run()
|
||||
}()
|
||||
conn.scheduleSending()
|
||||
|
@ -1734,7 +1734,7 @@ var _ = Describe("Connection", func() {
|
|||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
conn.run()
|
||||
}()
|
||||
// don't EXPECT any calls to mconn.Write()
|
||||
|
@ -1768,7 +1768,7 @@ var _ = Describe("Connection", func() {
|
|||
tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
conn.run()
|
||||
}()
|
||||
Eventually(written).Should(BeClosed())
|
||||
|
@ -1832,7 +1832,7 @@ var _ = Describe("Connection", func() {
|
|||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
conn.run()
|
||||
}()
|
||||
|
||||
|
@ -1864,7 +1864,7 @@ var _ = Describe("Connection", func() {
|
|||
go func() {
|
||||
defer GinkgoRecover()
|
||||
<-finishHandshake
|
||||
cryptoSetup.EXPECT().RunHandshake()
|
||||
cryptoSetup.EXPECT().StartHandshake()
|
||||
cryptoSetup.EXPECT().SetHandshakeConfirmed()
|
||||
cryptoSetup.EXPECT().GetSessionTicket()
|
||||
close(conn.handshakeCompleteChan)
|
||||
|
@ -1894,7 +1894,7 @@ var _ = Describe("Connection", func() {
|
|||
go func() {
|
||||
defer GinkgoRecover()
|
||||
<-finishHandshake
|
||||
cryptoSetup.EXPECT().RunHandshake()
|
||||
cryptoSetup.EXPECT().StartHandshake()
|
||||
cryptoSetup.EXPECT().SetHandshakeConfirmed()
|
||||
cryptoSetup.EXPECT().GetSessionTicket().Return(make([]byte, size), nil)
|
||||
close(conn.handshakeCompleteChan)
|
||||
|
@ -1941,7 +1941,7 @@ var _ = Describe("Connection", func() {
|
|||
tracer.EXPECT().Close()
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake()
|
||||
cryptoSetup.EXPECT().StartHandshake()
|
||||
conn.run()
|
||||
}()
|
||||
handshakeCtx := conn.HandshakeComplete()
|
||||
|
@ -1974,7 +1974,7 @@ var _ = Describe("Connection", func() {
|
|||
packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes()
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake()
|
||||
cryptoSetup.EXPECT().StartHandshake()
|
||||
cryptoSetup.EXPECT().SetHandshakeConfirmed()
|
||||
cryptoSetup.EXPECT().GetSessionTicket()
|
||||
mconn.EXPECT().Write(gomock.Any(), gomock.Any())
|
||||
|
@ -1997,7 +1997,7 @@ var _ = Describe("Connection", func() {
|
|||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
Expect(conn.run()).To(Succeed())
|
||||
close(done)
|
||||
}()
|
||||
|
@ -2017,7 +2017,7 @@ var _ = Describe("Connection", func() {
|
|||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
err := conn.run()
|
||||
Expect(err).To(MatchError(&qerr.ApplicationError{
|
||||
ErrorCode: 0x1337,
|
||||
|
@ -2069,7 +2069,7 @@ var _ = Describe("Connection", func() {
|
|||
runConn := func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
conn.run()
|
||||
}()
|
||||
}
|
||||
|
@ -2171,7 +2171,7 @@ var _ = Describe("Connection", func() {
|
|||
)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
err := conn.run()
|
||||
nerr, ok := err.(net.Error)
|
||||
Expect(ok).To(BeTrue())
|
||||
|
@ -2196,7 +2196,7 @@ var _ = Describe("Connection", func() {
|
|||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
err := conn.run()
|
||||
nerr, ok := err.(net.Error)
|
||||
Expect(ok).To(BeTrue())
|
||||
|
@ -2229,7 +2229,7 @@ var _ = Describe("Connection", func() {
|
|||
// and not on the last network activity
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
conn.run()
|
||||
}()
|
||||
Consistently(conn.Context().Done()).ShouldNot(BeClosed())
|
||||
|
@ -2256,7 +2256,7 @@ var _ = Describe("Connection", func() {
|
|||
conn.handshakeComplete = false
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1)
|
||||
err := conn.run()
|
||||
nerr, ok := err.(net.Error)
|
||||
|
@ -2285,7 +2285,7 @@ var _ = Describe("Connection", func() {
|
|||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().SetHandshakeConfirmed().MaxTimes(1)
|
||||
close(conn.handshakeCompleteChan)
|
||||
|
@ -2305,7 +2305,7 @@ var _ = Describe("Connection", func() {
|
|||
conn.idleTimeout = 30 * time.Second
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
conn.run()
|
||||
}()
|
||||
Consistently(conn.Context().Done()).ShouldNot(BeClosed())
|
||||
|
@ -2336,7 +2336,7 @@ var _ = Describe("Connection", func() {
|
|||
pto := conn.rttStats.PTO(true)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().SetHandshakeConfirmed().MaxTimes(1)
|
||||
close(conn.handshakeCompleteChan)
|
||||
|
@ -2508,7 +2508,7 @@ var _ = Describe("Client Connection", func() {
|
|||
conn.unpacker = unpacker
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
conn.run()
|
||||
}()
|
||||
newConnID := protocol.ParseConnectionID([]byte{1, 3, 3, 7, 1, 3, 3, 7})
|
||||
|
@ -2588,7 +2588,7 @@ var _ = Describe("Client Connection", func() {
|
|||
tracer.EXPECT().ClosedConnection(gomock.Any())
|
||||
tracer.EXPECT().Close()
|
||||
running := make(chan struct{})
|
||||
cryptoSetup.EXPECT().RunHandshake().Do(func() {
|
||||
cryptoSetup.EXPECT().StartHandshake().Do(func() {
|
||||
close(running)
|
||||
conn.closeLocal(errors.New("early error"))
|
||||
})
|
||||
|
@ -2641,7 +2641,7 @@ var _ = Describe("Client Connection", func() {
|
|||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
errChan <- conn.run()
|
||||
}()
|
||||
connRunner.EXPECT().Remove(srcConnID)
|
||||
|
@ -2666,7 +2666,7 @@ var _ = Describe("Client Connection", func() {
|
|||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
errChan <- conn.run()
|
||||
}()
|
||||
connRunner.EXPECT().Remove(srcConnID).MaxTimes(1)
|
||||
|
@ -2774,7 +2774,7 @@ var _ = Describe("Client Connection", func() {
|
|||
closed = false
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
|
||||
errChan <- conn.run()
|
||||
close(errChan)
|
||||
}()
|
||||
|
|
|
@ -71,17 +71,9 @@ func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error {
|
|||
|
||||
// GetCryptoData retrieves data that was received in CRYPTO frames
|
||||
func (s *cryptoStreamImpl) GetCryptoData() []byte {
|
||||
if len(s.msgBuf) < 4 {
|
||||
return nil
|
||||
}
|
||||
msgLen := 4 + int(s.msgBuf[1])<<16 + int(s.msgBuf[2])<<8 + int(s.msgBuf[3])
|
||||
if len(s.msgBuf) < msgLen {
|
||||
return nil
|
||||
}
|
||||
msg := make([]byte, msgLen)
|
||||
copy(msg, s.msgBuf[:msgLen])
|
||||
s.msgBuf = s.msgBuf[msgLen:]
|
||||
return msg
|
||||
b := s.msgBuf
|
||||
s.msgBuf = nil
|
||||
return b
|
||||
}
|
||||
|
||||
func (s *cryptoStreamImpl) Finish() error {
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
)
|
||||
|
||||
type cryptoDataHandler interface {
|
||||
HandleMessage([]byte, protocol.EncryptionLevel) bool
|
||||
HandleMessage([]byte, protocol.EncryptionLevel) error
|
||||
}
|
||||
|
||||
type cryptoStreamManager struct {
|
||||
|
@ -33,7 +33,7 @@ func newCryptoStreamManager(
|
|||
}
|
||||
}
|
||||
|
||||
func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) (bool /* encryption level changed */, error) {
|
||||
func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error {
|
||||
var str cryptoStream
|
||||
//nolint:exhaustive // CRYPTO frames cannot be sent in 0-RTT packets.
|
||||
switch encLevel {
|
||||
|
@ -44,18 +44,39 @@ func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLeve
|
|||
case protocol.Encryption1RTT:
|
||||
str = m.oneRTTStream
|
||||
default:
|
||||
return false, fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel)
|
||||
return fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel)
|
||||
}
|
||||
if err := str.HandleCryptoFrame(frame); err != nil {
|
||||
return false, err
|
||||
return err
|
||||
}
|
||||
for {
|
||||
data := str.GetCryptoData()
|
||||
if data == nil {
|
||||
return false, nil
|
||||
return nil
|
||||
}
|
||||
if encLevelFinished := m.cryptoHandler.HandleMessage(data, encLevel); encLevelFinished {
|
||||
return true, str.Finish()
|
||||
if err := m.cryptoHandler.HandleMessage(data, encLevel); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *cryptoStreamManager) GetPostHandshakeData(maxSize protocol.ByteCount) *wire.CryptoFrame {
|
||||
if !m.oneRTTStream.HasData() {
|
||||
return nil
|
||||
}
|
||||
return m.oneRTTStream.PopCryptoFrame(maxSize)
|
||||
}
|
||||
|
||||
func (m *cryptoStreamManager) Drop(encLevel protocol.EncryptionLevel) error {
|
||||
//nolint:exhaustive // 1-RTT keys should never get dropped.
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
return m.initialStream.Finish()
|
||||
case protocol.EncryptionHandshake:
|
||||
return m.handshakeStream.Finish()
|
||||
case protocol.Encryption0RTT:
|
||||
return nil
|
||||
default:
|
||||
panic(fmt.Sprintf("dropped unexpected encryption level: %s", encLevel))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,12 +1,9 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
@ -35,9 +32,7 @@ var _ = Describe("Crypto Stream Manager", func() {
|
|||
initialStream.EXPECT().GetCryptoData().Return([]byte("foobar"))
|
||||
initialStream.EXPECT().GetCryptoData()
|
||||
cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionInitial)
|
||||
encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionInitial)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(encLevelChanged).To(BeFalse())
|
||||
Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionInitial)).To(Succeed())
|
||||
})
|
||||
|
||||
It("passes messages to the handshake stream", func() {
|
||||
|
@ -46,9 +41,7 @@ var _ = Describe("Crypto Stream Manager", func() {
|
|||
handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar"))
|
||||
handshakeStream.EXPECT().GetCryptoData()
|
||||
cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake)
|
||||
encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(encLevelChanged).To(BeFalse())
|
||||
Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed())
|
||||
})
|
||||
|
||||
It("passes messages to the 1-RTT stream", func() {
|
||||
|
@ -57,9 +50,7 @@ var _ = Describe("Crypto Stream Manager", func() {
|
|||
oneRTTStream.EXPECT().GetCryptoData().Return([]byte("foobar"))
|
||||
oneRTTStream.EXPECT().GetCryptoData()
|
||||
cs.EXPECT().HandleMessage([]byte("foobar"), protocol.Encryption1RTT)
|
||||
encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.Encryption1RTT)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(encLevelChanged).To(BeFalse())
|
||||
Expect(csm.HandleCryptoFrame(cf, protocol.Encryption1RTT)).To(Succeed())
|
||||
})
|
||||
|
||||
It("doesn't call the message handler, if there's no message", func() {
|
||||
|
@ -67,9 +58,7 @@ var _ = Describe("Crypto Stream Manager", func() {
|
|||
handshakeStream.EXPECT().HandleCryptoFrame(cf)
|
||||
handshakeStream.EXPECT().GetCryptoData() // don't return any data to handle
|
||||
// don't EXPECT any calls to HandleMessage()
|
||||
encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(encLevelChanged).To(BeFalse())
|
||||
Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed())
|
||||
})
|
||||
|
||||
It("processes all messages", func() {
|
||||
|
@ -80,40 +69,26 @@ var _ = Describe("Crypto Stream Manager", func() {
|
|||
handshakeStream.EXPECT().GetCryptoData()
|
||||
cs.EXPECT().HandleMessage([]byte("foo"), protocol.EncryptionHandshake)
|
||||
cs.EXPECT().HandleMessage([]byte("bar"), protocol.EncryptionHandshake)
|
||||
encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(encLevelChanged).To(BeFalse())
|
||||
})
|
||||
|
||||
It("finishes the crypto stream, when the crypto setup is done with this encryption level", func() {
|
||||
cf := &wire.CryptoFrame{Data: []byte("foobar")}
|
||||
gomock.InOrder(
|
||||
handshakeStream.EXPECT().HandleCryptoFrame(cf),
|
||||
handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar")),
|
||||
cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake).Return(true),
|
||||
handshakeStream.EXPECT().Finish(),
|
||||
)
|
||||
encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(encLevelChanged).To(BeTrue())
|
||||
})
|
||||
|
||||
It("returns errors that occur when finishing a stream", func() {
|
||||
testErr := errors.New("test error")
|
||||
cf := &wire.CryptoFrame{Data: []byte("foobar")}
|
||||
gomock.InOrder(
|
||||
handshakeStream.EXPECT().HandleCryptoFrame(cf),
|
||||
handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar")),
|
||||
cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake).Return(true),
|
||||
handshakeStream.EXPECT().Finish().Return(testErr),
|
||||
)
|
||||
_, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)
|
||||
Expect(err).To(MatchError(err))
|
||||
Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed())
|
||||
})
|
||||
|
||||
It("errors for unknown encryption levels", func() {
|
||||
_, err := csm.HandleCryptoFrame(&wire.CryptoFrame{}, 42)
|
||||
err := csm.HandleCryptoFrame(&wire.CryptoFrame{}, 42)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("received CRYPTO frame with unexpected encryption level"))
|
||||
})
|
||||
|
||||
It("drops Initial", func() {
|
||||
initialStream.EXPECT().Finish()
|
||||
Expect(csm.Drop(protocol.EncryptionInitial)).To(Succeed())
|
||||
})
|
||||
|
||||
It("drops Handshake", func() {
|
||||
handshakeStream.EXPECT().Finish()
|
||||
Expect(csm.Drop(protocol.EncryptionHandshake)).To(Succeed())
|
||||
})
|
||||
|
||||
It("no-ops when dropping 0-RTT", func() {
|
||||
Expect(csm.Drop(protocol.Encryption0RTT)).To(Succeed())
|
||||
})
|
||||
})
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
|
@ -12,16 +11,6 @@ import (
|
|||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func createHandshakeMessage(len int) []byte {
|
||||
msg := make([]byte, 4+len)
|
||||
rand.Read(msg[:1]) // random message type
|
||||
msg[1] = uint8(len >> 16)
|
||||
msg[2] = uint8(len >> 8)
|
||||
msg[3] = uint8(len)
|
||||
rand.Read(msg[4:])
|
||||
return msg
|
||||
}
|
||||
|
||||
var _ = Describe("Crypto Stream", func() {
|
||||
var str cryptoStream
|
||||
|
||||
|
@ -31,21 +20,11 @@ var _ = Describe("Crypto Stream", func() {
|
|||
|
||||
Context("handling incoming data", func() {
|
||||
It("handles in-order CRYPTO frames", func() {
|
||||
msg := createHandshakeMessage(6)
|
||||
err := str.HandleCryptoFrame(&wire.CryptoFrame{Data: msg})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.GetCryptoData()).To(Equal(msg))
|
||||
Expect(str.HandleCryptoFrame(&wire.CryptoFrame{Data: []byte("foo")})).To(Succeed())
|
||||
Expect(str.GetCryptoData()).To(Equal([]byte("foo")))
|
||||
Expect(str.GetCryptoData()).To(BeNil())
|
||||
})
|
||||
|
||||
It("handles multiple messages in one CRYPTO frame", func() {
|
||||
msg1 := createHandshakeMessage(6)
|
||||
msg2 := createHandshakeMessage(10)
|
||||
msg := append(append([]byte{}, msg1...), msg2...)
|
||||
err := str.HandleCryptoFrame(&wire.CryptoFrame{Data: msg})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.GetCryptoData()).To(Equal(msg1))
|
||||
Expect(str.GetCryptoData()).To(Equal(msg2))
|
||||
Expect(str.HandleCryptoFrame(&wire.CryptoFrame{Data: []byte("bar"), Offset: 3})).To(Succeed())
|
||||
Expect(str.GetCryptoData()).To(Equal([]byte("bar")))
|
||||
Expect(str.GetCryptoData()).To(BeNil())
|
||||
})
|
||||
|
||||
|
@ -59,42 +38,17 @@ var _ = Describe("Crypto Stream", func() {
|
|||
}))
|
||||
})
|
||||
|
||||
It("handles messages split over multiple CRYPTO frames", func() {
|
||||
msg := createHandshakeMessage(6)
|
||||
err := str.HandleCryptoFrame(&wire.CryptoFrame{
|
||||
Data: msg[:4],
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.GetCryptoData()).To(BeNil())
|
||||
err = str.HandleCryptoFrame(&wire.CryptoFrame{
|
||||
Offset: 4,
|
||||
Data: msg[4:],
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.GetCryptoData()).To(Equal(msg))
|
||||
Expect(str.GetCryptoData()).To(BeNil())
|
||||
})
|
||||
|
||||
It("handles out-of-order CRYPTO frames", func() {
|
||||
msg := createHandshakeMessage(6)
|
||||
err := str.HandleCryptoFrame(&wire.CryptoFrame{
|
||||
Offset: 4,
|
||||
Data: msg[4:],
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.GetCryptoData()).To(BeNil())
|
||||
err = str.HandleCryptoFrame(&wire.CryptoFrame{
|
||||
Data: msg[:4],
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.GetCryptoData()).To(Equal(msg))
|
||||
Expect(str.HandleCryptoFrame(&wire.CryptoFrame{Offset: 3, Data: []byte("bar")})).To(Succeed())
|
||||
Expect(str.HandleCryptoFrame(&wire.CryptoFrame{Data: []byte("foo")})).To(Succeed())
|
||||
Expect(str.GetCryptoData()).To(Equal([]byte("foobar")))
|
||||
Expect(str.GetCryptoData()).To(BeNil())
|
||||
})
|
||||
|
||||
Context("finishing", func() {
|
||||
It("errors if there's still data to read after finishing", func() {
|
||||
Expect(str.HandleCryptoFrame(&wire.CryptoFrame{
|
||||
Data: createHandshakeMessage(5),
|
||||
Data: []byte("foobar"),
|
||||
Offset: 10,
|
||||
})).To(Succeed())
|
||||
Expect(str.Finish()).To(MatchError(&qerr.TransportError{
|
||||
|
@ -120,7 +74,7 @@ var _ = Describe("Crypto Stream", func() {
|
|||
It("rejects new crypto data after finishing", func() {
|
||||
Expect(str.Finish()).To(Succeed())
|
||||
Expect(str.HandleCryptoFrame(&wire.CryptoFrame{
|
||||
Data: createHandshakeMessage(5),
|
||||
Data: []byte("foo"),
|
||||
})).To(MatchError(&qerr.TransportError{
|
||||
ErrorCode: qerr.ProtocolViolation,
|
||||
ErrorMessage: "received crypto data after change of encryption level",
|
||||
|
@ -128,15 +82,14 @@ var _ = Describe("Crypto Stream", func() {
|
|||
})
|
||||
|
||||
It("ignores crypto data below the maximum offset received before finishing", func() {
|
||||
msg := createHandshakeMessage(15)
|
||||
Expect(str.HandleCryptoFrame(&wire.CryptoFrame{
|
||||
Data: msg,
|
||||
Data: []byte("foobar"),
|
||||
})).To(Succeed())
|
||||
Expect(str.GetCryptoData()).To(Equal(msg))
|
||||
Expect(str.GetCryptoData()).To(Equal([]byte("foobar")))
|
||||
Expect(str.Finish()).To(Succeed())
|
||||
Expect(str.HandleCryptoFrame(&wire.CryptoFrame{
|
||||
Offset: protocol.ByteCount(len(msg) - 6),
|
||||
Data: []byte("foobar"),
|
||||
Offset: 2,
|
||||
Data: []byte("foo"),
|
||||
})).To(Succeed())
|
||||
})
|
||||
})
|
||||
|
|
|
@ -43,26 +43,24 @@ func initStreams() (chan chunk, *stream /* initial */, *stream /* handshake */)
|
|||
type handshakeRunner interface {
|
||||
OnReceivedParams(*wire.TransportParameters)
|
||||
OnHandshakeComplete()
|
||||
OnError(error)
|
||||
OnReceivedReadKeys()
|
||||
DropKeys(protocol.EncryptionLevel)
|
||||
}
|
||||
|
||||
type runner struct {
|
||||
client, server *handshake.CryptoSetup
|
||||
handshakeComplete chan<- struct{}
|
||||
}
|
||||
|
||||
var _ handshakeRunner = &runner{}
|
||||
|
||||
func newRunner(client, server *handshake.CryptoSetup) *runner {
|
||||
return &runner{client: client, server: server}
|
||||
func newRunner(handshakeComplete chan<- struct{}) *runner {
|
||||
return &runner{handshakeComplete: handshakeComplete}
|
||||
}
|
||||
|
||||
func (r *runner) OnReceivedParams(*wire.TransportParameters) {}
|
||||
func (r *runner) OnHandshakeComplete() {}
|
||||
func (r *runner) OnError(err error) {
|
||||
(*r.client).Close()
|
||||
(*r.server).Close()
|
||||
log.Fatal("runner error:", err)
|
||||
func (r *runner) OnReceivedReadKeys() {}
|
||||
func (r *runner) OnHandshakeComplete() {
|
||||
close(r.handshakeComplete)
|
||||
}
|
||||
func (r *runner) DropKeys(protocol.EncryptionLevel) {}
|
||||
|
||||
|
@ -71,16 +69,16 @@ const alpn = "fuzz"
|
|||
func main() {
|
||||
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||
var client, server handshake.CryptoSetup
|
||||
runner := newRunner(&client, &server)
|
||||
clientHandshakeCompleted := make(chan struct{})
|
||||
client, _ = handshake.NewCryptoSetupClient(
|
||||
cInitialStream,
|
||||
cHandshakeStream,
|
||||
nil,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
nil,
|
||||
&wire.TransportParameters{ActiveConnectionIDLimit: 2},
|
||||
runner,
|
||||
newRunner(clientHandshakeCompleted),
|
||||
&tls.Config{
|
||||
MinVersion: tls.VersionTLS13,
|
||||
ServerName: "localhost",
|
||||
NextProtos: []string{alpn},
|
||||
RootCAs: testdata.GetRootCA(),
|
||||
|
@ -96,14 +94,14 @@ func main() {
|
|||
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
||||
config := testdata.GetTLSConfig()
|
||||
config.NextProtos = []string{alpn}
|
||||
serverHandshakeCompleted := make(chan struct{})
|
||||
server = handshake.NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
nil,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
nil,
|
||||
&wire.TransportParameters{ActiveConnectionIDLimit: 2},
|
||||
runner,
|
||||
newRunner(serverHandshakeCompleted),
|
||||
config,
|
||||
false,
|
||||
utils.NewRTTStats(),
|
||||
|
@ -112,17 +110,13 @@ func main() {
|
|||
protocol.Version1,
|
||||
)
|
||||
|
||||
serverHandshakeCompleted := make(chan struct{})
|
||||
go func() {
|
||||
defer close(serverHandshakeCompleted)
|
||||
server.RunHandshake()
|
||||
}()
|
||||
if err := client.StartHandshake(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
clientHandshakeCompleted := make(chan struct{})
|
||||
go func() {
|
||||
defer close(clientHandshakeCompleted)
|
||||
client.RunHandshake()
|
||||
}()
|
||||
if err := server.StartHandshake(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
|
@ -137,10 +131,14 @@ messageLoop:
|
|||
select {
|
||||
case c := <-cChunkChan:
|
||||
messages = append(messages, c.data)
|
||||
server.HandleMessage(c.data, c.encLevel)
|
||||
if err := server.HandleMessage(c.data, c.encLevel); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
case c := <-sChunkChan:
|
||||
messages = append(messages, c.data)
|
||||
client.HandleMessage(c.data, c.encLevel)
|
||||
if err := client.HandleMessage(c.data, c.encLevel); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
case <-done:
|
||||
break messageLoop
|
||||
}
|
||||
|
|
|
@ -11,7 +11,6 @@ import (
|
|||
"log"
|
||||
"math"
|
||||
mrand "math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/fuzzing/internal/helper"
|
||||
|
@ -157,39 +156,24 @@ func initStreams() (chan chunk, *stream /* initial */, *stream /* handshake */)
|
|||
type handshakeRunner interface {
|
||||
OnReceivedParams(*wire.TransportParameters)
|
||||
OnHandshakeComplete()
|
||||
OnError(error)
|
||||
OnReceivedReadKeys()
|
||||
DropKeys(protocol.EncryptionLevel)
|
||||
}
|
||||
|
||||
type runner struct {
|
||||
sync.Mutex
|
||||
errored bool
|
||||
client, server *handshake.CryptoSetup
|
||||
handshakeComplete chan<- struct{}
|
||||
}
|
||||
|
||||
var _ handshakeRunner = &runner{}
|
||||
|
||||
func newRunner(client, server *handshake.CryptoSetup) *runner {
|
||||
return &runner{client: client, server: server}
|
||||
func newRunner(handshakeComplete chan<- struct{}) *runner {
|
||||
return &runner{handshakeComplete: handshakeComplete}
|
||||
}
|
||||
|
||||
func (r *runner) OnReceivedParams(*wire.TransportParameters) {}
|
||||
func (r *runner) OnHandshakeComplete() {}
|
||||
func (r *runner) OnError(err error) {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
if r.errored {
|
||||
return
|
||||
}
|
||||
r.errored = true
|
||||
(*r.client).Close()
|
||||
(*r.server).Close()
|
||||
}
|
||||
|
||||
func (r *runner) Errored() bool {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
return r.errored
|
||||
func (r *runner) OnReceivedReadKeys() {}
|
||||
func (r *runner) OnHandshakeComplete() {
|
||||
close(r.handshakeComplete)
|
||||
}
|
||||
func (r *runner) DropKeys(protocol.EncryptionLevel) {}
|
||||
|
||||
|
@ -270,6 +254,7 @@ func Fuzz(data []byte) int {
|
|||
}
|
||||
|
||||
clientConf := &tls.Config{
|
||||
MinVersion: tls.VersionTLS13,
|
||||
ServerName: "localhost",
|
||||
NextProtos: []string{alpn},
|
||||
RootCAs: certPool,
|
||||
|
@ -287,6 +272,7 @@ func Fuzz(data []byte) int {
|
|||
|
||||
func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.Config, data []byte) int {
|
||||
serverConf := &tls.Config{
|
||||
MinVersion: tls.VersionTLS13,
|
||||
Certificates: []tls.Certificate{*cert},
|
||||
NextProtos: []string{alpn},
|
||||
SessionTicketKey: sessionTicketKey,
|
||||
|
@ -373,15 +359,14 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.
|
|||
|
||||
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||
var client, server handshake.CryptoSetup
|
||||
runner := newRunner(&client, &server)
|
||||
clientHandshakeCompleted := make(chan struct{})
|
||||
client, _ = handshake.NewCryptoSetupClient(
|
||||
cInitialStream,
|
||||
cHandshakeStream,
|
||||
nil,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
nil,
|
||||
clientTP,
|
||||
runner,
|
||||
newRunner(clientHandshakeCompleted),
|
||||
clientConf,
|
||||
enable0RTTClient,
|
||||
utils.NewRTTStats(),
|
||||
|
@ -390,15 +375,15 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.
|
|||
protocol.Version1,
|
||||
)
|
||||
|
||||
serverHandshakeCompleted := make(chan struct{})
|
||||
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
||||
server = handshake.NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
nil,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
nil,
|
||||
serverTP,
|
||||
runner,
|
||||
newRunner(serverHandshakeCompleted),
|
||||
serverConf,
|
||||
enable0RTTServer,
|
||||
utils.NewRTTStats(),
|
||||
|
@ -411,17 +396,13 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.
|
|||
return -1
|
||||
}
|
||||
|
||||
serverHandshakeCompleted := make(chan struct{})
|
||||
go func() {
|
||||
defer close(serverHandshakeCompleted)
|
||||
server.RunHandshake()
|
||||
}()
|
||||
if err := client.StartHandshake(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
clientHandshakeCompleted := make(chan struct{})
|
||||
go func() {
|
||||
defer close(clientHandshakeCompleted)
|
||||
client.RunHandshake()
|
||||
}()
|
||||
if err := server.StartHandshake(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
|
@ -441,7 +422,9 @@ messageLoop:
|
|||
b = data
|
||||
encLevel = maxEncLevel(server, messageToReplaceEncLevel)
|
||||
}
|
||||
server.HandleMessage(b, encLevel)
|
||||
if err := server.HandleMessage(b, encLevel); err != nil {
|
||||
break messageLoop
|
||||
}
|
||||
case c := <-sChunkChan:
|
||||
b := c.data
|
||||
encLevel := c.encLevel
|
||||
|
@ -450,11 +433,10 @@ messageLoop:
|
|||
b = data
|
||||
encLevel = maxEncLevel(client, messageToReplaceEncLevel)
|
||||
}
|
||||
client.HandleMessage(b, encLevel)
|
||||
case <-done: // test done
|
||||
if err := client.HandleMessage(b, encLevel); err != nil {
|
||||
break messageLoop
|
||||
}
|
||||
if runner.Errored() {
|
||||
case <-done: // test done
|
||||
break messageLoop
|
||||
}
|
||||
}
|
||||
|
@ -462,9 +444,6 @@ messageLoop:
|
|||
<-done
|
||||
_ = client.ConnectionState()
|
||||
_ = server.ConnectionState()
|
||||
if runner.Errored() {
|
||||
return 0
|
||||
}
|
||||
|
||||
sealer, err := client.Get1RTTSealer()
|
||||
if err != nil {
|
||||
|
|
3
go.mod
3
go.mod
|
@ -8,8 +8,7 @@ require (
|
|||
github.com/onsi/ginkgo/v2 v2.9.5
|
||||
github.com/onsi/gomega v1.27.6
|
||||
github.com/quic-go/qpack v0.4.0
|
||||
github.com/quic-go/qtls-go1-19 v0.3.2
|
||||
github.com/quic-go/qtls-go1-20 v0.2.2
|
||||
github.com/quic-go/qtls-go1-20 v0.3.0
|
||||
golang.org/x/crypto v0.4.0
|
||||
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db
|
||||
golang.org/x/net v0.10.0
|
||||
|
|
6
go.sum
6
go.sum
|
@ -90,10 +90,8 @@ github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7q
|
|||
github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
|
||||
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
|
||||
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
|
||||
github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U=
|
||||
github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI=
|
||||
github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E=
|
||||
github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM=
|
||||
github.com/quic-go/qtls-go1-20 v0.3.0 h1:NrCXmDl8BddZwO67vlvEpBTwT89bJfKYygxv4HQvuDk=
|
||||
github.com/quic-go/qtls-go1-20 v0.3.0/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k=
|
||||
github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
|
||||
github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=
|
||||
github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY=
|
||||
|
|
|
@ -15,7 +15,6 @@ import (
|
|||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qtls"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
|
||||
|
@ -402,7 +401,7 @@ func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str qui
|
|||
return nil, newConnError(ErrCodeGeneralProtocolError, err)
|
||||
}
|
||||
|
||||
connState := qtls.ToTLSConnectionState(conn.ConnectionState().TLS)
|
||||
connState := conn.ConnectionState().TLS
|
||||
res := &http.Response{
|
||||
Proto: "HTTP/3.0",
|
||||
ProtoMajor: 3,
|
||||
|
|
|
@ -577,7 +577,7 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q
|
|||
return newStreamError(ErrCodeGeneralProtocolError, err)
|
||||
}
|
||||
|
||||
connState := conn.ConnectionState().TLS.ConnectionState
|
||||
connState := conn.ConnectionState().TLS
|
||||
req.TLS = &connState
|
||||
req.RemoteAddr = conn.RemoteAddr().String()
|
||||
body := newRequestBody(newStream(str, onFrameError))
|
||||
|
|
|
@ -926,7 +926,7 @@ var _ = Describe("Server", func() {
|
|||
c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer c.CloseWithError(0, "")
|
||||
Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3))
|
||||
Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3))
|
||||
})
|
||||
|
||||
It("sets the GetConfigForClient callback if no tls.Config is given", func() {
|
||||
|
@ -954,7 +954,7 @@ var _ = Describe("Server", func() {
|
|||
c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer c.CloseWithError(0, "")
|
||||
Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3))
|
||||
Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3))
|
||||
})
|
||||
|
||||
It("works if GetConfigForClient returns a nil tls.Config", func() {
|
||||
|
@ -967,7 +967,7 @@ var _ = Describe("Server", func() {
|
|||
c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer c.CloseWithError(0, "")
|
||||
Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3))
|
||||
Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3))
|
||||
})
|
||||
|
||||
It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient, if it returns a static tls.Config", func() {
|
||||
|
@ -985,7 +985,7 @@ var _ = Describe("Server", func() {
|
|||
c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer c.CloseWithError(0, "")
|
||||
Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3))
|
||||
Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3))
|
||||
// check that the original config was not modified
|
||||
Expect(tlsClientConf.NextProtos).To(Equal([]string{"foo", "bar"}))
|
||||
})
|
||||
|
|
|
@ -136,10 +136,8 @@ github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7q
|
|||
github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
|
||||
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
|
||||
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
|
||||
github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U=
|
||||
github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI=
|
||||
github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E=
|
||||
github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM=
|
||||
github.com/quic-go/qtls-go1-20 v0.3.0 h1:NrCXmDl8BddZwO67vlvEpBTwT89bJfKYygxv4HQvuDk=
|
||||
github.com/quic-go/qtls-go1-20 v0.3.0/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k=
|
||||
github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
|
||||
github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=
|
||||
github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY=
|
||||
|
@ -185,7 +183,6 @@ golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnf
|
|||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190313024323-a1f597ede03a/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
|
||||
|
|
|
@ -198,7 +198,10 @@ var _ = Describe("Handshake tests", func() {
|
|||
var transportErr *quic.TransportError
|
||||
Expect(errors.As(err, &transportErr)).To(BeTrue())
|
||||
Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue())
|
||||
Expect(transportErr.Error()).To(ContainSubstring("tls: bad certificate"))
|
||||
Expect(transportErr.Error()).To(Or(
|
||||
ContainSubstring("tls: certificate required"),
|
||||
ContainSubstring("tls: bad certificate"),
|
||||
))
|
||||
})
|
||||
|
||||
It("uses the ServerName in the tls.Config", func() {
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
|
||||
|
@ -14,16 +14,15 @@ import (
|
|||
)
|
||||
|
||||
type clientSessionCache struct {
|
||||
mutex sync.Mutex
|
||||
cache map[string]*tls.ClientSessionState
|
||||
cache tls.ClientSessionCache
|
||||
|
||||
gets chan<- string
|
||||
puts chan<- string
|
||||
}
|
||||
|
||||
func newClientSessionCache(gets, puts chan<- string) *clientSessionCache {
|
||||
func newClientSessionCache(cache tls.ClientSessionCache, gets, puts chan<- string) *clientSessionCache {
|
||||
return &clientSessionCache{
|
||||
cache: make(map[string]*tls.ClientSessionState),
|
||||
cache: cache,
|
||||
gets: gets,
|
||||
puts: puts,
|
||||
}
|
||||
|
@ -32,29 +31,25 @@ func newClientSessionCache(gets, puts chan<- string) *clientSessionCache {
|
|||
var _ tls.ClientSessionCache = &clientSessionCache{}
|
||||
|
||||
func (c *clientSessionCache) Get(sessionKey string) (*tls.ClientSessionState, bool) {
|
||||
session, ok := c.cache.Get(sessionKey)
|
||||
c.gets <- sessionKey
|
||||
c.mutex.Lock()
|
||||
session, ok := c.cache[sessionKey]
|
||||
c.mutex.Unlock()
|
||||
return session, ok
|
||||
}
|
||||
|
||||
func (c *clientSessionCache) Put(sessionKey string, cs *tls.ClientSessionState) {
|
||||
c.cache.Put(sessionKey, cs)
|
||||
c.puts <- sessionKey
|
||||
c.mutex.Lock()
|
||||
c.cache[sessionKey] = cs
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
var _ = Describe("TLS session resumption", func() {
|
||||
It("uses session resumption", func() {
|
||||
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), nil)
|
||||
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer server.Close()
|
||||
|
||||
gets := make(chan string, 100)
|
||||
puts := make(chan string, 100)
|
||||
cache := newClientSessionCache(gets, puts)
|
||||
cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
|
||||
tlsConf := getTLSClientConfig()
|
||||
tlsConf.ClientSessionCache = cache
|
||||
conn, err := quic.DialAddr(
|
||||
|
@ -96,7 +91,7 @@ var _ = Describe("TLS session resumption", func() {
|
|||
|
||||
gets := make(chan string, 100)
|
||||
puts := make(chan string, 100)
|
||||
cache := newClientSessionCache(gets, puts)
|
||||
cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
|
||||
tlsConf := getTLSClientConfig()
|
||||
tlsConf.ClientSessionCache = cache
|
||||
conn, err := quic.DialAddr(
|
||||
|
@ -109,7 +104,9 @@ var _ = Describe("TLS session resumption", func() {
|
|||
Consistently(puts).ShouldNot(Receive())
|
||||
Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())
|
||||
|
||||
serverConn, err := server.Accept(context.Background())
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
serverConn, err := server.Accept(ctx)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
|
||||
|
||||
|
|
804
integrationtests/self/zero_rtt_oldgo_test.go
Normal file
804
integrationtests/self/zero_rtt_oldgo_test.go
Normal file
|
@ -0,0 +1,804 @@
|
|||
//go:build !go1.21
|
||||
|
||||
package self_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
mrand "math/rand"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("0-RTT", func() {
|
||||
rtt := scaleDuration(5 * time.Millisecond)
|
||||
|
||||
runCountingProxy := func(serverPort int) (*quicproxy.QuicProxy, *uint32) {
|
||||
var num0RTTPackets uint32 // to be used as an atomic
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
|
||||
DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration {
|
||||
for len(data) > 0 {
|
||||
if !wire.IsLongHeaderPacket(data[0]) {
|
||||
break
|
||||
}
|
||||
hdr, _, rest, err := wire.ParsePacket(data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
if hdr.Type == protocol.PacketType0RTT {
|
||||
atomic.AddUint32(&num0RTTPackets, 1)
|
||||
break
|
||||
}
|
||||
data = rest
|
||||
}
|
||||
return rtt / 2
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
return proxy, &num0RTTPackets
|
||||
}
|
||||
|
||||
dialAndReceiveSessionTicket := func(serverConf *quic.Config) (*tls.Config, *tls.Config) {
|
||||
tlsConf := getTLSConfig()
|
||||
if serverConf == nil {
|
||||
serverConf = getQuicConfig(nil)
|
||||
}
|
||||
serverConf.Allow0RTT = true
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
serverConf,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
|
||||
DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration { return rtt / 2 },
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer proxy.Close()
|
||||
|
||||
// dial the first connection in order to receive a session ticket
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(done)
|
||||
conn, err := ln.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
<-conn.Context().Done()
|
||||
}()
|
||||
|
||||
clientConf := getTLSClientConfig()
|
||||
gets := make(chan string, 100)
|
||||
puts := make(chan string, 100)
|
||||
clientConf.ClientSessionCache = newClientSessionCache(tls.NewLRUClientSessionCache(100), gets, puts)
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
|
||||
clientConf,
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(puts).Should(Receive())
|
||||
// received the session ticket. We're done here.
|
||||
Expect(conn.CloseWithError(0, "")).To(Succeed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
return tlsConf, clientConf
|
||||
}
|
||||
|
||||
transfer0RTTData := func(
|
||||
ln *quic.EarlyListener,
|
||||
proxyPort int,
|
||||
connIDLen int,
|
||||
clientTLSConf *tls.Config,
|
||||
clientConf *quic.Config,
|
||||
testdata []byte, // data to transfer
|
||||
) {
|
||||
// accept the second connection, and receive the data sent in 0-RTT
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
conn, err := ln.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := io.ReadAll(str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal(testdata))
|
||||
Expect(str.Close()).To(Succeed())
|
||||
Expect(conn.ConnectionState().Used0RTT).To(BeTrue())
|
||||
<-conn.Context().Done()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
if clientConf == nil {
|
||||
clientConf = getQuicConfig(nil)
|
||||
}
|
||||
var conn quic.EarlyConnection
|
||||
if connIDLen == 0 {
|
||||
var err error
|
||||
conn, err = quic.DialAddrEarly(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxyPort),
|
||||
clientTLSConf,
|
||||
clientConf,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
} else {
|
||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
udpConn, err := net.ListenUDP("udp", addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer udpConn.Close()
|
||||
tr := &quic.Transport{
|
||||
Conn: udpConn,
|
||||
ConnectionIDLength: connIDLen,
|
||||
}
|
||||
defer tr.Close()
|
||||
conn, err = tr.DialEarly(
|
||||
context.Background(),
|
||||
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: proxyPort},
|
||||
clientTLSConf,
|
||||
clientConf,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
defer conn.CloseWithError(0, "")
|
||||
str, err := conn.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = str.Write(testdata)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.Close()).To(Succeed())
|
||||
<-conn.HandshakeComplete()
|
||||
Expect(conn.ConnectionState().Used0RTT).To(BeTrue())
|
||||
io.ReadAll(str) // wait for the EOF from the server to arrive before closing the conn
|
||||
conn.CloseWithError(0, "")
|
||||
Eventually(done).Should(BeClosed())
|
||||
Eventually(conn.Context().Done()).Should(BeClosed())
|
||||
}
|
||||
|
||||
check0RTTRejected := func(
|
||||
ln *quic.EarlyListener,
|
||||
proxyPort int,
|
||||
clientConf *tls.Config,
|
||||
) {
|
||||
conn, err := quic.DialAddrEarly(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxyPort),
|
||||
clientConf,
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.OpenUniStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = str.Write(make([]byte, 3000))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.Close()).To(Succeed())
|
||||
Expect(conn.ConnectionState().Used0RTT).To(BeFalse())
|
||||
|
||||
// make sure the server doesn't process the data
|
||||
ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(50*time.Millisecond))
|
||||
defer cancel()
|
||||
serverConn, err := ln.Accept(ctx)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serverConn.ConnectionState().Used0RTT).To(BeFalse())
|
||||
_, err = serverConn.AcceptUniStream(ctx)
|
||||
Expect(err).To(Equal(context.DeadlineExceeded))
|
||||
Expect(serverConn.CloseWithError(0, "")).To(Succeed())
|
||||
Eventually(conn.Context().Done()).Should(BeClosed())
|
||||
}
|
||||
|
||||
// can be used to extract 0-RTT from a packetTracer
|
||||
get0RTTPackets := func(packets []packet) []protocol.PacketNumber {
|
||||
var zeroRTTPackets []protocol.PacketNumber
|
||||
for _, p := range packets {
|
||||
if p.hdr.Type == protocol.PacketType0RTT {
|
||||
zeroRTTPackets = append(zeroRTTPackets, p.hdr.PacketNumber)
|
||||
}
|
||||
}
|
||||
return zeroRTTPackets
|
||||
}
|
||||
|
||||
for _, l := range []int{0, 15} {
|
||||
connIDLen := l
|
||||
|
||||
It(fmt.Sprintf("transfers 0-RTT data, with %d byte connection IDs", connIDLen), func() {
|
||||
tlsConf, clientTLSConf := dialAndReceiveSessionTicket(nil)
|
||||
|
||||
tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
|
||||
proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
|
||||
defer proxy.Close()
|
||||
|
||||
transfer0RTTData(
|
||||
ln,
|
||||
proxy.LocalPort(),
|
||||
connIDLen,
|
||||
clientTLSConf,
|
||||
getQuicConfig(nil),
|
||||
PRData,
|
||||
)
|
||||
|
||||
var numNewConnIDs int
|
||||
for _, p := range tracer.getRcvdLongHeaderPackets() {
|
||||
for _, f := range p.frames {
|
||||
if _, ok := f.(*logging.NewConnectionIDFrame); ok {
|
||||
numNewConnIDs++
|
||||
}
|
||||
}
|
||||
}
|
||||
if connIDLen == 0 {
|
||||
Expect(numNewConnIDs).To(BeZero())
|
||||
} else {
|
||||
Expect(numNewConnIDs).ToNot(BeZero())
|
||||
}
|
||||
|
||||
num0RTT := atomic.LoadUint32(num0RTTPackets)
|
||||
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
|
||||
Expect(num0RTT).ToNot(BeZero())
|
||||
zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets())
|
||||
Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10))
|
||||
Expect(zeroRTTPackets).To(ContainElement(protocol.PacketNumber(0)))
|
||||
})
|
||||
}
|
||||
|
||||
// Test that data intended to be sent with 1-RTT protection is not sent in 0-RTT packets.
|
||||
It("waits for a connection until the handshake is done", func() {
|
||||
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
|
||||
|
||||
zeroRTTData := GeneratePRData(5 << 10)
|
||||
oneRTTData := PRData
|
||||
|
||||
tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
|
||||
// now accept the second connection, and receive the 0-RTT data
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
conn, err := ln.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := io.ReadAll(str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal(zeroRTTData))
|
||||
str, err = conn.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err = io.ReadAll(str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal(oneRTTData))
|
||||
Expect(conn.CloseWithError(0, "")).To(Succeed())
|
||||
}()
|
||||
|
||||
proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
|
||||
defer proxy.Close()
|
||||
|
||||
conn, err := quic.DialAddrEarly(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
|
||||
clientConf,
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
firstStr, err := conn.OpenUniStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = firstStr.Write(zeroRTTData)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(firstStr.Close()).To(Succeed())
|
||||
|
||||
// wait for the handshake to complete
|
||||
Eventually(conn.HandshakeComplete()).Should(BeClosed())
|
||||
str, err := conn.OpenUniStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = str.Write(PRData)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.Close()).To(Succeed())
|
||||
<-conn.Context().Done()
|
||||
|
||||
// check that 0-RTT packets only contain STREAM frames for the first stream
|
||||
var num0RTT int
|
||||
for _, p := range tracer.getRcvdLongHeaderPackets() {
|
||||
if p.hdr.Header.Type != protocol.PacketType0RTT {
|
||||
continue
|
||||
}
|
||||
for _, f := range p.frames {
|
||||
sf, ok := f.(*logging.StreamFrame)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
num0RTT++
|
||||
Expect(sf.StreamID).To(Equal(firstStr.StreamID()))
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(GinkgoWriter, "received %d STREAM frames in 0-RTT packets\n", num0RTT)
|
||||
Expect(num0RTT).ToNot(BeZero())
|
||||
})
|
||||
|
||||
It("transfers 0-RTT data, when 0-RTT packets are lost", func() {
|
||||
var (
|
||||
num0RTTPackets uint32 // to be used as an atomic
|
||||
num0RTTDropped uint32
|
||||
)
|
||||
|
||||
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
|
||||
|
||||
tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
|
||||
DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration {
|
||||
if wire.IsLongHeaderPacket(data[0]) {
|
||||
hdr, _, _, err := wire.ParsePacket(data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
if hdr.Type == protocol.PacketType0RTT {
|
||||
atomic.AddUint32(&num0RTTPackets, 1)
|
||||
}
|
||||
}
|
||||
return rtt / 2
|
||||
},
|
||||
DropPacket: func(_ quicproxy.Direction, data []byte) bool {
|
||||
if !wire.IsLongHeaderPacket(data[0]) {
|
||||
return false
|
||||
}
|
||||
hdr, _, _, err := wire.ParsePacket(data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
if hdr.Type == protocol.PacketType0RTT {
|
||||
// drop 25% of the 0-RTT packets
|
||||
drop := mrand.Intn(4) == 0
|
||||
if drop {
|
||||
atomic.AddUint32(&num0RTTDropped, 1)
|
||||
}
|
||||
return drop
|
||||
}
|
||||
return false
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer proxy.Close()
|
||||
|
||||
transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData)
|
||||
|
||||
num0RTT := atomic.LoadUint32(&num0RTTPackets)
|
||||
numDropped := atomic.LoadUint32(&num0RTTDropped)
|
||||
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets. Dropped %d of those.", num0RTT, numDropped)
|
||||
Expect(numDropped).ToNot(BeZero())
|
||||
Expect(num0RTT).ToNot(BeZero())
|
||||
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).ToNot(BeEmpty())
|
||||
})
|
||||
|
||||
It("retransmits all 0-RTT data when the server performs a Retry", func() {
|
||||
var mutex sync.Mutex
|
||||
var firstConnID, secondConnID *protocol.ConnectionID
|
||||
var firstCounter, secondCounter protocol.ByteCount
|
||||
|
||||
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
|
||||
|
||||
countZeroRTTBytes := func(data []byte) (n protocol.ByteCount) {
|
||||
for len(data) > 0 {
|
||||
hdr, _, rest, err := wire.ParsePacket(data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
data = rest
|
||||
if hdr.Type == protocol.PacketType0RTT {
|
||||
n += hdr.Length - 16 /* AEAD tag */
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
RequireAddressValidation: func(net.Addr) bool { return true },
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
|
||||
DelayPacket: func(dir quicproxy.Direction, data []byte) time.Duration {
|
||||
connID, err := wire.ParseConnectionID(data, 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
|
||||
if zeroRTTBytes := countZeroRTTBytes(data); zeroRTTBytes > 0 {
|
||||
if firstConnID == nil {
|
||||
firstConnID = &connID
|
||||
firstCounter += zeroRTTBytes
|
||||
} else if firstConnID != nil && *firstConnID == connID {
|
||||
Expect(secondConnID).To(BeNil())
|
||||
firstCounter += zeroRTTBytes
|
||||
} else if secondConnID == nil {
|
||||
secondConnID = &connID
|
||||
secondCounter += zeroRTTBytes
|
||||
} else if secondConnID != nil && *secondConnID == connID {
|
||||
secondCounter += zeroRTTBytes
|
||||
} else {
|
||||
Fail("received 3 connection IDs on 0-RTT packets")
|
||||
}
|
||||
}
|
||||
return rtt / 2
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer proxy.Close()
|
||||
|
||||
transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, GeneratePRData(5000)) // ~5 packets
|
||||
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
Expect(firstCounter).To(BeNumerically("~", 5000+100 /* framing overhead */, 100)) // the FIN bit might be sent extra
|
||||
Expect(secondCounter).To(BeNumerically("~", firstCounter, 20))
|
||||
zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets())
|
||||
Expect(len(zeroRTTPackets)).To(BeNumerically(">=", 5))
|
||||
Expect(zeroRTTPackets[0]).To(BeNumerically(">=", protocol.PacketNumber(5)))
|
||||
})
|
||||
|
||||
It("doesn't reject 0-RTT when the server's transport stream limit increased", func() {
|
||||
const maxStreams = 1
|
||||
tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{
|
||||
MaxIncomingUniStreams: maxStreams,
|
||||
}))
|
||||
|
||||
tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
MaxIncomingUniStreams: maxStreams + 1,
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
|
||||
defer proxy.Close()
|
||||
|
||||
conn, err := quic.DialAddrEarly(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
|
||||
clientConf,
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.OpenUniStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = str.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.Close()).To(Succeed())
|
||||
// The client remembers the old limit and refuses to open a new stream.
|
||||
_, err = conn.OpenUniStream()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("too many open streams"))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
_, err = conn.OpenUniStreamSync(ctx)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(conn.ConnectionState().Used0RTT).To(BeTrue())
|
||||
Expect(conn.CloseWithError(0, "")).To(Succeed())
|
||||
})
|
||||
|
||||
It("rejects 0-RTT when the server's stream limit decreased", func() {
|
||||
const maxStreams = 42
|
||||
tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{
|
||||
MaxIncomingStreams: maxStreams,
|
||||
}))
|
||||
|
||||
tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
MaxIncomingStreams: maxStreams - 1,
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
|
||||
defer proxy.Close()
|
||||
check0RTTRejected(ln, proxy.LocalPort(), clientConf)
|
||||
|
||||
// The client should send 0-RTT packets, but the server doesn't process them.
|
||||
num0RTT := atomic.LoadUint32(num0RTTPackets)
|
||||
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
|
||||
Expect(num0RTT).ToNot(BeZero())
|
||||
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("rejects 0-RTT when the ALPN changed", func() {
|
||||
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
|
||||
|
||||
// now close the listener and dial new connection with a different ALPN
|
||||
clientConf.NextProtos = []string{"new-alpn"}
|
||||
tlsConf.NextProtos = []string{"new-alpn"}
|
||||
tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
|
||||
defer proxy.Close()
|
||||
|
||||
check0RTTRejected(ln, proxy.LocalPort(), clientConf)
|
||||
|
||||
// The client should send 0-RTT packets, but the server doesn't process them.
|
||||
num0RTT := atomic.LoadUint32(num0RTTPackets)
|
||||
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
|
||||
Expect(num0RTT).ToNot(BeZero())
|
||||
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("rejects 0-RTT when the application doesn't allow it", func() {
|
||||
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
|
||||
|
||||
// now close the listener and dial new connection with a different ALPN
|
||||
tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
Allow0RTT: false, // application rejects 0-RTT
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
|
||||
defer proxy.Close()
|
||||
|
||||
check0RTTRejected(ln, proxy.LocalPort(), clientConf)
|
||||
|
||||
// The client should send 0-RTT packets, but the server doesn't process them.
|
||||
num0RTT := atomic.LoadUint32(num0RTTPackets)
|
||||
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
|
||||
Expect(num0RTT).ToNot(BeZero())
|
||||
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty())
|
||||
})
|
||||
|
||||
DescribeTable("flow control limits",
|
||||
func(addFlowControlLimit func(*quic.Config, uint64)) {
|
||||
tracer := newPacketTracer()
|
||||
firstConf := getQuicConfig(&quic.Config{Allow0RTT: true})
|
||||
addFlowControlLimit(firstConf, 3)
|
||||
tlsConf, clientConf := dialAndReceiveSessionTicket(firstConf)
|
||||
|
||||
secondConf := getQuicConfig(&quic.Config{
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(tracer),
|
||||
})
|
||||
addFlowControlLimit(secondConf, 100)
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
secondConf,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
|
||||
defer proxy.Close()
|
||||
|
||||
conn, err := quic.DialAddrEarly(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
|
||||
clientConf,
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.OpenUniStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
written := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(written)
|
||||
_, err := str.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.Close()).To(Succeed())
|
||||
}()
|
||||
|
||||
Eventually(written).Should(BeClosed())
|
||||
|
||||
serverConn, err := ln.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
rstr, err := serverConn.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := io.ReadAll(rstr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal([]byte("foobar")))
|
||||
Expect(serverConn.ConnectionState().Used0RTT).To(BeTrue())
|
||||
Expect(serverConn.CloseWithError(0, "")).To(Succeed())
|
||||
Eventually(conn.Context().Done()).Should(BeClosed())
|
||||
|
||||
var processedFirst bool
|
||||
for _, p := range tracer.getRcvdLongHeaderPackets() {
|
||||
for _, f := range p.frames {
|
||||
if sf, ok := f.(*logging.StreamFrame); ok {
|
||||
if !processedFirst {
|
||||
// The first STREAM should have been sent in a 0-RTT packet.
|
||||
// Due to the flow control limit, the STREAM frame was limit to the first 3 bytes.
|
||||
Expect(p.hdr.Type).To(Equal(protocol.PacketType0RTT))
|
||||
Expect(sf.Length).To(BeEquivalentTo(3))
|
||||
processedFirst = true
|
||||
} else {
|
||||
Fail("STREAM was shouldn't have been sent in 0-RTT")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
Entry("doesn't reject 0-RTT when the server's transport stream flow control limit increased", func(c *quic.Config, limit uint64) { c.InitialStreamReceiveWindow = limit }),
|
||||
Entry("doesn't reject 0-RTT when the server's transport connection flow control limit increased", func(c *quic.Config, limit uint64) { c.InitialConnectionReceiveWindow = limit }),
|
||||
)
|
||||
|
||||
for _, l := range []int{0, 15} {
|
||||
connIDLen := l
|
||||
|
||||
It(fmt.Sprintf("correctly deals with 0-RTT rejections, for %d byte connection IDs", connIDLen), func() {
|
||||
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
|
||||
// now dial new connection with different transport parameters
|
||||
tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
MaxIncomingUniStreams: 1,
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
|
||||
defer proxy.Close()
|
||||
|
||||
conn, err := quic.DialAddrEarly(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
|
||||
clientConf,
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// The client remembers that it was allowed to open 2 uni-directional streams.
|
||||
firstStr, err := conn.OpenUniStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
written := make(chan struct{}, 2)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer func() { written <- struct{}{} }()
|
||||
_, err := firstStr.Write([]byte("first flight"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}()
|
||||
secondStr, err := conn.OpenUniStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer func() { written <- struct{}{} }()
|
||||
_, err := secondStr.Write([]byte("first flight"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
_, err = conn.AcceptStream(ctx)
|
||||
Expect(err).To(MatchError(quic.Err0RTTRejected))
|
||||
Eventually(written).Should(Receive())
|
||||
Eventually(written).Should(Receive())
|
||||
_, err = firstStr.Write([]byte("foobar"))
|
||||
Expect(err).To(MatchError(quic.Err0RTTRejected))
|
||||
_, err = conn.OpenUniStream()
|
||||
Expect(err).To(MatchError(quic.Err0RTTRejected))
|
||||
|
||||
_, err = conn.AcceptStream(ctx)
|
||||
Expect(err).To(Equal(quic.Err0RTTRejected))
|
||||
|
||||
newConn := conn.NextConnection()
|
||||
str, err := newConn.OpenUniStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = newConn.OpenUniStream()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("too many open streams"))
|
||||
_, err = str.Write([]byte("second flight"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.Close()).To(Succeed())
|
||||
Expect(conn.CloseWithError(0, "")).To(Succeed())
|
||||
|
||||
// The client should send 0-RTT packets, but the server doesn't process them.
|
||||
num0RTT := atomic.LoadUint32(num0RTTPackets)
|
||||
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
|
||||
Expect(num0RTT).ToNot(BeZero())
|
||||
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty())
|
||||
})
|
||||
}
|
||||
|
||||
It("queues 0-RTT packets, if the Initial is delayed", func() {
|
||||
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
|
||||
|
||||
tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: ln.Addr().String(),
|
||||
DelayPacket: func(dir quicproxy.Direction, data []byte) time.Duration {
|
||||
if dir == quicproxy.DirectionIncoming && wire.IsLongHeaderPacket(data[0]) && data[0]&0x30>>4 == 0 { // Initial packet from client
|
||||
return rtt/2 + rtt
|
||||
}
|
||||
return rtt / 2
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer proxy.Close()
|
||||
|
||||
transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData)
|
||||
|
||||
Expect(tracer.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial))
|
||||
zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets())
|
||||
Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10))
|
||||
Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0)))
|
||||
})
|
||||
})
|
|
@ -1,3 +1,5 @@
|
|||
//go:build go1.21
|
||||
|
||||
package self_test
|
||||
|
||||
import (
|
||||
|
@ -21,6 +23,36 @@ import (
|
|||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type metadataClientSessionCache struct {
|
||||
toAdd []byte
|
||||
restored func([]byte)
|
||||
|
||||
cache tls.ClientSessionCache
|
||||
}
|
||||
|
||||
func (m metadataClientSessionCache) Get(key string) (*tls.ClientSessionState, bool) {
|
||||
session, ok := m.cache.Get(key)
|
||||
if !ok || session == nil {
|
||||
return session, ok
|
||||
}
|
||||
ticket, state, err := session.ResumptionState()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(state.Extra).To(HaveLen(2)) // ours, and the quic-go's
|
||||
m.restored(state.Extra[1])
|
||||
session, err = tls.NewResumptionState(ticket, state)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return session, true
|
||||
}
|
||||
|
||||
func (m metadataClientSessionCache) Put(key string, session *tls.ClientSessionState) {
|
||||
ticket, state, err := session.ResumptionState()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
state.Extra = append(state.Extra, m.toAdd)
|
||||
session, err = tls.NewResumptionState(ticket, state)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
m.cache.Put(key, session)
|
||||
}
|
||||
|
||||
var _ = Describe("0-RTT", func() {
|
||||
rtt := scaleDuration(5 * time.Millisecond)
|
||||
|
||||
|
@ -49,15 +81,14 @@ var _ = Describe("0-RTT", func() {
|
|||
return proxy, &num0RTTPackets
|
||||
}
|
||||
|
||||
dialAndReceiveSessionTicket := func(serverConf *quic.Config) (*tls.Config, *tls.Config) {
|
||||
tlsConf := getTLSConfig()
|
||||
dialAndReceiveSessionTicket := func(serverTLSConf *tls.Config, serverConf *quic.Config, clientTLSConf *tls.Config) {
|
||||
if serverConf == nil {
|
||||
serverConf = getQuicConfig(nil)
|
||||
}
|
||||
serverConf.Allow0RTT = true
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
serverTLSConf,
|
||||
serverConf,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -80,14 +111,16 @@ var _ = Describe("0-RTT", func() {
|
|||
<-conn.Context().Done()
|
||||
}()
|
||||
|
||||
clientConf := getTLSClientConfig()
|
||||
gets := make(chan string, 100)
|
||||
puts := make(chan string, 100)
|
||||
clientConf.ClientSessionCache = newClientSessionCache(gets, puts)
|
||||
cache := clientTLSConf.ClientSessionCache
|
||||
if cache == nil {
|
||||
cache = tls.NewLRUClientSessionCache(100)
|
||||
}
|
||||
clientTLSConf.ClientSessionCache = newClientSessionCache(cache, make(chan string, 100), puts)
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
|
||||
clientConf,
|
||||
clientTLSConf,
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -95,7 +128,6 @@ var _ = Describe("0-RTT", func() {
|
|||
// received the session ticket. We're done here.
|
||||
Expect(conn.CloseWithError(0, "")).To(Succeed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
return tlsConf, clientConf
|
||||
}
|
||||
|
||||
transfer0RTTData := func(
|
||||
|
@ -118,7 +150,7 @@ var _ = Describe("0-RTT", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal(testdata))
|
||||
Expect(str.Close()).To(Succeed())
|
||||
Expect(conn.ConnectionState().TLS.Used0RTT).To(BeTrue())
|
||||
Expect(conn.ConnectionState().Used0RTT).To(BeTrue())
|
||||
<-conn.Context().Done()
|
||||
close(done)
|
||||
}()
|
||||
|
@ -162,7 +194,7 @@ var _ = Describe("0-RTT", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.Close()).To(Succeed())
|
||||
<-conn.HandshakeComplete()
|
||||
Expect(conn.ConnectionState().TLS.Used0RTT).To(BeTrue())
|
||||
Expect(conn.ConnectionState().Used0RTT).To(BeTrue())
|
||||
io.ReadAll(str) // wait for the EOF from the server to arrive before closing the conn
|
||||
conn.CloseWithError(0, "")
|
||||
Eventually(done).Should(BeClosed())
|
||||
|
@ -186,14 +218,14 @@ var _ = Describe("0-RTT", func() {
|
|||
_, err = str.Write(make([]byte, 3000))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.Close()).To(Succeed())
|
||||
Expect(conn.ConnectionState().TLS.Used0RTT).To(BeFalse())
|
||||
Expect(conn.ConnectionState().Used0RTT).To(BeFalse())
|
||||
|
||||
// make sure the server doesn't process the data
|
||||
ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(50*time.Millisecond))
|
||||
defer cancel()
|
||||
serverConn, err := ln.Accept(ctx)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serverConn.ConnectionState().TLS.Used0RTT).To(BeFalse())
|
||||
Expect(serverConn.ConnectionState().Used0RTT).To(BeFalse())
|
||||
_, err = serverConn.AcceptUniStream(ctx)
|
||||
Expect(err).To(Equal(context.DeadlineExceeded))
|
||||
Expect(serverConn.CloseWithError(0, "")).To(Succeed())
|
||||
|
@ -215,7 +247,9 @@ var _ = Describe("0-RTT", func() {
|
|||
connIDLen := l
|
||||
|
||||
It(fmt.Sprintf("transfers 0-RTT data, with %d byte connection IDs", connIDLen), func() {
|
||||
tlsConf, clientTLSConf := dialAndReceiveSessionTicket(nil)
|
||||
tlsConf := getTLSConfig()
|
||||
clientTLSConf := getTLSClientConfig()
|
||||
dialAndReceiveSessionTicket(tlsConf, nil, clientTLSConf)
|
||||
|
||||
tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
|
@ -266,7 +300,9 @@ var _ = Describe("0-RTT", func() {
|
|||
|
||||
// Test that data intended to be sent with 1-RTT protection is not sent in 0-RTT packets.
|
||||
It("waits for a connection until the handshake is done", func() {
|
||||
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
|
||||
tlsConf := getTLSConfig()
|
||||
clientConf := getTLSClientConfig()
|
||||
dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
|
||||
|
||||
zeroRTTData := GeneratePRData(5 << 10)
|
||||
oneRTTData := PRData
|
||||
|
@ -351,7 +387,9 @@ var _ = Describe("0-RTT", func() {
|
|||
num0RTTDropped uint32
|
||||
)
|
||||
|
||||
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
|
||||
tlsConf := getTLSConfig()
|
||||
clientConf := getTLSClientConfig()
|
||||
dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
|
||||
|
||||
tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
|
@ -412,7 +450,9 @@ var _ = Describe("0-RTT", func() {
|
|||
var firstConnID, secondConnID *protocol.ConnectionID
|
||||
var firstCounter, secondCounter protocol.ByteCount
|
||||
|
||||
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
|
||||
tlsConf := getTLSConfig()
|
||||
clientConf := getTLSClientConfig()
|
||||
dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
|
||||
|
||||
countZeroRTTBytes := func(data []byte) (n protocol.ByteCount) {
|
||||
for len(data) > 0 {
|
||||
|
@ -485,9 +525,11 @@ var _ = Describe("0-RTT", func() {
|
|||
|
||||
It("doesn't reject 0-RTT when the server's transport stream limit increased", func() {
|
||||
const maxStreams = 1
|
||||
tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{
|
||||
tlsConf := getTLSConfig()
|
||||
clientConf := getTLSClientConfig()
|
||||
dialAndReceiveSessionTicket(tlsConf, getQuicConfig(&quic.Config{
|
||||
MaxIncomingUniStreams: maxStreams,
|
||||
}))
|
||||
}), clientConf)
|
||||
|
||||
tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
|
@ -524,15 +566,17 @@ var _ = Describe("0-RTT", func() {
|
|||
defer cancel()
|
||||
_, err = conn.OpenUniStreamSync(ctx)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(conn.ConnectionState().TLS.Used0RTT).To(BeTrue())
|
||||
Expect(conn.ConnectionState().Used0RTT).To(BeTrue())
|
||||
Expect(conn.CloseWithError(0, "")).To(Succeed())
|
||||
})
|
||||
|
||||
It("rejects 0-RTT when the server's stream limit decreased", func() {
|
||||
const maxStreams = 42
|
||||
tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{
|
||||
tlsConf := getTLSConfig()
|
||||
clientConf := getTLSClientConfig()
|
||||
dialAndReceiveSessionTicket(tlsConf, getQuicConfig(&quic.Config{
|
||||
MaxIncomingStreams: maxStreams,
|
||||
}))
|
||||
}), clientConf)
|
||||
|
||||
tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
|
@ -548,6 +592,7 @@ var _ = Describe("0-RTT", func() {
|
|||
defer ln.Close()
|
||||
proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
|
||||
defer proxy.Close()
|
||||
|
||||
check0RTTRejected(ln, proxy.LocalPort(), clientConf)
|
||||
|
||||
// The client should send 0-RTT packets, but the server doesn't process them.
|
||||
|
@ -558,11 +603,15 @@ var _ = Describe("0-RTT", func() {
|
|||
})
|
||||
|
||||
It("rejects 0-RTT when the ALPN changed", func() {
|
||||
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
|
||||
tlsConf := getTLSConfig()
|
||||
clientConf := getTLSClientConfig()
|
||||
dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
|
||||
|
||||
// now close the listener and dial new connection with a different ALPN
|
||||
clientConf.NextProtos = []string{"new-alpn"}
|
||||
// switch to different ALPN on the server side
|
||||
tlsConf.NextProtos = []string{"new-alpn"}
|
||||
// Append to the client's ALPN.
|
||||
// crypto/tls will attempt to resume with the ALPN from the original connection
|
||||
clientConf.NextProtos = append(clientConf.NextProtos, "new-alpn")
|
||||
tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
|
@ -587,7 +636,9 @@ var _ = Describe("0-RTT", func() {
|
|||
})
|
||||
|
||||
It("rejects 0-RTT when the application doesn't allow it", func() {
|
||||
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
|
||||
tlsConf := getTLSConfig()
|
||||
clientConf := getTLSClientConfig()
|
||||
dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
|
||||
|
||||
// now close the listener and dial new connection with a different ALPN
|
||||
tracer := newPacketTracer()
|
||||
|
@ -618,7 +669,9 @@ var _ = Describe("0-RTT", func() {
|
|||
tracer := newPacketTracer()
|
||||
firstConf := getQuicConfig(&quic.Config{Allow0RTT: true})
|
||||
addFlowControlLimit(firstConf, 3)
|
||||
tlsConf, clientConf := dialAndReceiveSessionTicket(firstConf)
|
||||
tlsConf := getTLSConfig()
|
||||
clientConf := getTLSClientConfig()
|
||||
dialAndReceiveSessionTicket(tlsConf, firstConf, clientConf)
|
||||
|
||||
secondConf := getQuicConfig(&quic.Config{
|
||||
Allow0RTT: true,
|
||||
|
@ -662,7 +715,7 @@ var _ = Describe("0-RTT", func() {
|
|||
data, err := io.ReadAll(rstr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal([]byte("foobar")))
|
||||
Expect(serverConn.ConnectionState().TLS.Used0RTT).To(BeTrue())
|
||||
Expect(serverConn.ConnectionState().Used0RTT).To(BeTrue())
|
||||
Expect(serverConn.CloseWithError(0, "")).To(Succeed())
|
||||
Eventually(conn.Context().Done()).Should(BeClosed())
|
||||
|
||||
|
@ -691,7 +744,9 @@ var _ = Describe("0-RTT", func() {
|
|||
connIDLen := l
|
||||
|
||||
It(fmt.Sprintf("correctly deals with 0-RTT rejections, for %d byte connection IDs", connIDLen), func() {
|
||||
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
|
||||
tlsConf := getTLSConfig()
|
||||
clientConf := getTLSClientConfig()
|
||||
dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
|
||||
// now dial new connection with different transport parameters
|
||||
tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
|
@ -767,7 +822,9 @@ var _ = Describe("0-RTT", func() {
|
|||
}
|
||||
|
||||
It("queues 0-RTT packets, if the Initial is delayed", func() {
|
||||
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
|
||||
tlsConf := getTLSConfig()
|
||||
clientConf := getTLSClientConfig()
|
||||
dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
|
||||
|
||||
tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
|
@ -799,4 +856,86 @@ var _ = Describe("0-RTT", func() {
|
|||
Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10))
|
||||
Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0)))
|
||||
})
|
||||
|
||||
It("allows the application to attach data to the session ticket, for the server", func() {
|
||||
tlsConf := getTLSConfig()
|
||||
tlsConf.WrapSession = func(cs tls.ConnectionState, ss *tls.SessionState) ([]byte, error) {
|
||||
ss.Extra = append(ss.Extra, []byte("foobar"))
|
||||
return tlsConf.EncryptTicket(cs, ss)
|
||||
}
|
||||
var unwrapped bool
|
||||
tlsConf.UnwrapSession = func(identity []byte, cs tls.ConnectionState) (*tls.SessionState, error) {
|
||||
defer GinkgoRecover()
|
||||
state, err := tlsConf.DecryptTicket(identity, cs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
Expect(state.Extra).To(HaveLen(2))
|
||||
Expect(state.Extra[1]).To(Equal([]byte("foobar")))
|
||||
unwrapped = true
|
||||
return state, nil
|
||||
}
|
||||
clientTLSConf := getTLSClientConfig()
|
||||
dialAndReceiveSessionTicket(tlsConf, nil, clientTLSConf)
|
||||
|
||||
tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
|
||||
transfer0RTTData(
|
||||
ln,
|
||||
ln.Addr().(*net.UDPAddr).Port,
|
||||
10,
|
||||
clientTLSConf,
|
||||
getQuicConfig(nil),
|
||||
PRData,
|
||||
)
|
||||
Expect(unwrapped).To(BeTrue())
|
||||
})
|
||||
|
||||
It("allows the application to attach data to the session ticket, for the client", func() {
|
||||
tlsConf := getTLSConfig()
|
||||
clientTLSConf := getTLSClientConfig()
|
||||
var restored bool
|
||||
clientTLSConf.ClientSessionCache = &metadataClientSessionCache{
|
||||
toAdd: []byte("foobar"),
|
||||
restored: func(b []byte) {
|
||||
defer GinkgoRecover()
|
||||
Expect(b).To(Equal([]byte("foobar")))
|
||||
restored = true
|
||||
},
|
||||
cache: tls.NewLRUClientSessionCache(100),
|
||||
}
|
||||
dialAndReceiveSessionTicket(tlsConf, nil, clientTLSConf)
|
||||
|
||||
tracer := newPacketTracer()
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
|
||||
transfer0RTTData(
|
||||
ln,
|
||||
ln.Addr().(*net.UDPAddr).Port,
|
||||
10,
|
||||
clientTLSConf,
|
||||
getQuicConfig(nil),
|
||||
PRData,
|
||||
)
|
||||
Expect(restored).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
|
|
@ -2,6 +2,7 @@ package quic
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
|
@ -336,12 +337,14 @@ type ClientHelloInfo struct {
|
|||
// ConnectionState records basic details about a QUIC connection
|
||||
type ConnectionState struct {
|
||||
// TLS contains information about the TLS connection state, incl. the tls.ConnectionState.
|
||||
TLS handshake.ConnectionState
|
||||
TLS tls.ConnectionState
|
||||
// SupportsDatagrams says if support for QUIC datagrams (RFC 9221) was negotiated.
|
||||
// This requires both nodes to support and enable the datagram extensions (via Config.EnableDatagrams).
|
||||
// If datagram support was negotiated, datagrams can be sent and received using the
|
||||
// SendMessage and ReceiveMessage methods on the Connection.
|
||||
SupportsDatagrams bool
|
||||
// Used0RTT says if 0-RTT resumption was used.
|
||||
Used0RTT bool
|
||||
// Version is the QUIC version of the QUIC connection.
|
||||
Version VersionNumber
|
||||
}
|
||||
|
|
|
@ -5,11 +5,10 @@ import (
|
|||
"encoding/binary"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qtls"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
func createAEAD(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, v protocol.VersionNumber) cipher.AEAD {
|
||||
func createAEAD(suite *cipherSuite, trafficSecret []byte, v protocol.VersionNumber) cipher.AEAD {
|
||||
keyLabel := hkdfLabelKeyV1
|
||||
ivLabel := hkdfLabelIVV1
|
||||
if v == protocol.Version2 {
|
||||
|
|
104
internal/handshake/cipher_suite.go
Normal file
104
internal/handshake/cipher_suite.go
Normal file
|
@ -0,0 +1,104 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
// These cipher suite implementations are copied from the standard library crypto/tls package.
|
||||
|
||||
const aeadNonceLength = 12
|
||||
|
||||
type cipherSuite struct {
|
||||
ID uint16
|
||||
Hash crypto.Hash
|
||||
KeyLen int
|
||||
AEAD func(key, nonceMask []byte) cipher.AEAD
|
||||
}
|
||||
|
||||
func (s cipherSuite) IVLen() int { return aeadNonceLength }
|
||||
|
||||
func getCipherSuite(id uint16) *cipherSuite {
|
||||
switch id {
|
||||
case tls.TLS_AES_128_GCM_SHA256:
|
||||
return &cipherSuite{ID: tls.TLS_AES_128_GCM_SHA256, Hash: crypto.SHA256, KeyLen: 16, AEAD: aeadAESGCMTLS13}
|
||||
case tls.TLS_CHACHA20_POLY1305_SHA256:
|
||||
return &cipherSuite{ID: tls.TLS_CHACHA20_POLY1305_SHA256, Hash: crypto.SHA256, KeyLen: 32, AEAD: aeadChaCha20Poly1305}
|
||||
case tls.TLS_AES_256_GCM_SHA384:
|
||||
return &cipherSuite{ID: tls.TLS_AES_256_GCM_SHA384, Hash: crypto.SHA256, KeyLen: 32, AEAD: aeadAESGCMTLS13}
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown cypher suite: %d", id))
|
||||
}
|
||||
}
|
||||
|
||||
func aeadAESGCMTLS13(key, nonceMask []byte) cipher.AEAD {
|
||||
if len(nonceMask) != aeadNonceLength {
|
||||
panic("tls: internal error: wrong nonce length")
|
||||
}
|
||||
aes, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
aead, err := cipher.NewGCM(aes)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ret := &xorNonceAEAD{aead: aead}
|
||||
copy(ret.nonceMask[:], nonceMask)
|
||||
return ret
|
||||
}
|
||||
|
||||
func aeadChaCha20Poly1305(key, nonceMask []byte) cipher.AEAD {
|
||||
if len(nonceMask) != aeadNonceLength {
|
||||
panic("tls: internal error: wrong nonce length")
|
||||
}
|
||||
aead, err := chacha20poly1305.New(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ret := &xorNonceAEAD{aead: aead}
|
||||
copy(ret.nonceMask[:], nonceMask)
|
||||
return ret
|
||||
}
|
||||
|
||||
// xorNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce
|
||||
// before each call.
|
||||
type xorNonceAEAD struct {
|
||||
nonceMask [aeadNonceLength]byte
|
||||
aead cipher.AEAD
|
||||
}
|
||||
|
||||
func (f *xorNonceAEAD) NonceSize() int { return 8 } // 64-bit sequence number
|
||||
func (f *xorNonceAEAD) Overhead() int { return f.aead.Overhead() }
|
||||
func (f *xorNonceAEAD) explicitNonceLen() int { return 0 }
|
||||
|
||||
func (f *xorNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte {
|
||||
for i, b := range nonce {
|
||||
f.nonceMask[4+i] ^= b
|
||||
}
|
||||
result := f.aead.Seal(out, f.nonceMask[:], plaintext, additionalData)
|
||||
for i, b := range nonce {
|
||||
f.nonceMask[4+i] ^= b
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (f *xorNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) {
|
||||
for i, b := range nonce {
|
||||
f.nonceMask[4+i] ^= b
|
||||
}
|
||||
result, err := f.aead.Open(out, f.nonceMask[:], ciphertext, additionalData)
|
||||
for i, b := range nonce {
|
||||
f.nonceMask[4+i] ^= b
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
|
@ -7,9 +7,8 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
|
@ -25,96 +24,20 @@ type quicVersionContextKey struct{}
|
|||
|
||||
var QUICVersionContextKey = &quicVersionContextKey{}
|
||||
|
||||
// TLS unexpected_message alert
|
||||
const alertUnexpectedMessage uint8 = 10
|
||||
|
||||
type messageType uint8
|
||||
|
||||
// TLS handshake message types.
|
||||
const (
|
||||
typeClientHello messageType = 1
|
||||
typeServerHello messageType = 2
|
||||
typeNewSessionTicket messageType = 4
|
||||
typeEncryptedExtensions messageType = 8
|
||||
typeCertificate messageType = 11
|
||||
typeCertificateRequest messageType = 13
|
||||
typeCertificateVerify messageType = 15
|
||||
typeFinished messageType = 20
|
||||
)
|
||||
|
||||
func (m messageType) String() string {
|
||||
switch m {
|
||||
case typeClientHello:
|
||||
return "ClientHello"
|
||||
case typeServerHello:
|
||||
return "ServerHello"
|
||||
case typeNewSessionTicket:
|
||||
return "NewSessionTicket"
|
||||
case typeEncryptedExtensions:
|
||||
return "EncryptedExtensions"
|
||||
case typeCertificate:
|
||||
return "Certificate"
|
||||
case typeCertificateRequest:
|
||||
return "CertificateRequest"
|
||||
case typeCertificateVerify:
|
||||
return "CertificateVerify"
|
||||
case typeFinished:
|
||||
return "Finished"
|
||||
default:
|
||||
return fmt.Sprintf("unknown message type: %d", m)
|
||||
}
|
||||
}
|
||||
|
||||
const clientSessionStateRevision = 3
|
||||
|
||||
type conn struct {
|
||||
localAddr, remoteAddr net.Addr
|
||||
}
|
||||
|
||||
var _ net.Conn = &conn{}
|
||||
|
||||
func newConn(local, remote net.Addr) net.Conn {
|
||||
return &conn{
|
||||
localAddr: local,
|
||||
remoteAddr: remote,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *conn) Read([]byte) (int, error) { return 0, nil }
|
||||
func (c *conn) Write([]byte) (int, error) { return 0, nil }
|
||||
func (c *conn) Close() error { return nil }
|
||||
func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr }
|
||||
func (c *conn) LocalAddr() net.Addr { return c.localAddr }
|
||||
func (c *conn) SetReadDeadline(time.Time) error { return nil }
|
||||
func (c *conn) SetWriteDeadline(time.Time) error { return nil }
|
||||
func (c *conn) SetDeadline(time.Time) error { return nil }
|
||||
|
||||
type cryptoSetup struct {
|
||||
tlsConf *tls.Config
|
||||
extraConf *qtls.ExtraConfig
|
||||
conn *qtls.Conn
|
||||
conn *qtls.QUICConn
|
||||
|
||||
version protocol.VersionNumber
|
||||
|
||||
messageChan chan []byte
|
||||
isReadingHandshakeMessage chan struct{}
|
||||
readFirstHandshakeMessage bool
|
||||
|
||||
ourParams *wire.TransportParameters
|
||||
peerParams *wire.TransportParameters
|
||||
paramsChan <-chan []byte
|
||||
|
||||
runner handshakeRunner
|
||||
|
||||
alertChan chan uint8
|
||||
// handshakeDone is closed as soon as the go routine running qtls.Handshake() returns
|
||||
handshakeDone chan struct{}
|
||||
// is closed when Close() is called
|
||||
closeChan chan struct{}
|
||||
|
||||
zeroRTTParameters *wire.TransportParameters
|
||||
clientHelloWritten bool
|
||||
clientHelloWrittenChan chan struct{} // is closed as soon as the ClientHello is written
|
||||
zeroRTTParametersChan chan<- *wire.TransportParameters
|
||||
allow0RTT bool
|
||||
|
||||
|
@ -129,9 +52,6 @@ type cryptoSetup struct {
|
|||
|
||||
handshakeCompleteTime time.Time
|
||||
|
||||
readEncLevel protocol.EncryptionLevel
|
||||
writeEncLevel protocol.EncryptionLevel
|
||||
|
||||
zeroRTTOpener LongHeaderOpener // only set for the server
|
||||
zeroRTTSealer LongHeaderSealer // only set for the client
|
||||
|
||||
|
@ -143,23 +63,20 @@ type cryptoSetup struct {
|
|||
handshakeOpener LongHeaderOpener
|
||||
handshakeSealer LongHeaderSealer
|
||||
|
||||
used0RTT atomic.Bool
|
||||
|
||||
oneRTTStream io.Writer
|
||||
aead *updatableAEAD
|
||||
has1RTTSealer bool
|
||||
has1RTTOpener bool
|
||||
}
|
||||
|
||||
var (
|
||||
_ qtls.RecordLayer = &cryptoSetup{}
|
||||
_ CryptoSetup = &cryptoSetup{}
|
||||
)
|
||||
var _ CryptoSetup = &cryptoSetup{}
|
||||
|
||||
// NewCryptoSetupClient creates a new crypto setup for the client
|
||||
func NewCryptoSetupClient(
|
||||
initialStream io.Writer,
|
||||
handshakeStream io.Writer,
|
||||
initialStream, handshakeStream, oneRTTStream io.Writer,
|
||||
connID protocol.ConnectionID,
|
||||
localAddr net.Addr,
|
||||
remoteAddr net.Addr,
|
||||
tp *wire.TransportParameters,
|
||||
runner handshakeRunner,
|
||||
tlsConf *tls.Config,
|
||||
|
@ -172,28 +89,33 @@ func NewCryptoSetupClient(
|
|||
cs, clientHelloWritten := newCryptoSetup(
|
||||
initialStream,
|
||||
handshakeStream,
|
||||
oneRTTStream,
|
||||
connID,
|
||||
tp,
|
||||
runner,
|
||||
tlsConf,
|
||||
enable0RTT,
|
||||
rttStats,
|
||||
tracer,
|
||||
logger,
|
||||
protocol.PerspectiveClient,
|
||||
version,
|
||||
)
|
||||
cs.conn = qtls.Client(newConn(localAddr, remoteAddr), cs.tlsConf, cs.extraConf)
|
||||
|
||||
tlsConf = tlsConf.Clone()
|
||||
tlsConf.MinVersion = tls.VersionTLS13
|
||||
quicConf := &qtls.QUICConfig{TLSConfig: tlsConf}
|
||||
qtls.SetupConfigForClient(quicConf, cs.marshalDataForSessionState, cs.handleDataFromSessionState)
|
||||
cs.tlsConf = tlsConf
|
||||
|
||||
cs.conn = qtls.QUICClient(quicConf)
|
||||
cs.conn.SetTransportParameters(cs.ourParams.Marshal(protocol.PerspectiveClient))
|
||||
|
||||
return cs, clientHelloWritten
|
||||
}
|
||||
|
||||
// NewCryptoSetupServer creates a new crypto setup for the server
|
||||
func NewCryptoSetupServer(
|
||||
initialStream io.Writer,
|
||||
handshakeStream io.Writer,
|
||||
initialStream, handshakeStream, oneRTTStream io.Writer,
|
||||
connID protocol.ConnectionID,
|
||||
localAddr net.Addr,
|
||||
remoteAddr net.Addr,
|
||||
tp *wire.TransportParameters,
|
||||
runner handshakeRunner,
|
||||
tlsConf *tls.Config,
|
||||
|
@ -206,29 +128,32 @@ func NewCryptoSetupServer(
|
|||
cs, _ := newCryptoSetup(
|
||||
initialStream,
|
||||
handshakeStream,
|
||||
oneRTTStream,
|
||||
connID,
|
||||
tp,
|
||||
runner,
|
||||
tlsConf,
|
||||
allow0RTT,
|
||||
rttStats,
|
||||
tracer,
|
||||
logger,
|
||||
protocol.PerspectiveServer,
|
||||
version,
|
||||
)
|
||||
cs.conn = qtls.Server(newConn(localAddr, remoteAddr), cs.tlsConf, cs.extraConf)
|
||||
cs.allow0RTT = allow0RTT
|
||||
|
||||
quicConf := &qtls.QUICConfig{TLSConfig: tlsConf}
|
||||
qtls.SetupConfigForServer(quicConf, cs.allow0RTT, cs.getDataForSessionTicket, cs.accept0RTT)
|
||||
|
||||
cs.tlsConf = quicConf.TLSConfig
|
||||
cs.conn = qtls.QUICServer(quicConf)
|
||||
|
||||
return cs
|
||||
}
|
||||
|
||||
func newCryptoSetup(
|
||||
initialStream io.Writer,
|
||||
handshakeStream io.Writer,
|
||||
initialStream, handshakeStream, oneRTTStream io.Writer,
|
||||
connID protocol.ConnectionID,
|
||||
tp *wire.TransportParameters,
|
||||
runner handshakeRunner,
|
||||
tlsConf *tls.Config,
|
||||
enable0RTT bool,
|
||||
rttStats *utils.RTTStats,
|
||||
tracer logging.ConnectionTracer,
|
||||
logger utils.Logger,
|
||||
|
@ -240,51 +165,23 @@ func newCryptoSetup(
|
|||
tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient)
|
||||
tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer)
|
||||
}
|
||||
extHandler := newExtensionHandler(tp.Marshal(perspective), perspective)
|
||||
zeroRTTParametersChan := make(chan *wire.TransportParameters, 1)
|
||||
cs := &cryptoSetup{
|
||||
tlsConf: tlsConf,
|
||||
return &cryptoSetup{
|
||||
initialStream: initialStream,
|
||||
initialSealer: initialSealer,
|
||||
initialOpener: initialOpener,
|
||||
handshakeStream: handshakeStream,
|
||||
oneRTTStream: oneRTTStream,
|
||||
aead: newUpdatableAEAD(rttStats, tracer, logger, version),
|
||||
readEncLevel: protocol.EncryptionInitial,
|
||||
writeEncLevel: protocol.EncryptionInitial,
|
||||
runner: runner,
|
||||
allow0RTT: enable0RTT,
|
||||
ourParams: tp,
|
||||
paramsChan: extHandler.TransportParameters(),
|
||||
rttStats: rttStats,
|
||||
tracer: tracer,
|
||||
logger: logger,
|
||||
perspective: perspective,
|
||||
handshakeDone: make(chan struct{}),
|
||||
alertChan: make(chan uint8),
|
||||
clientHelloWrittenChan: make(chan struct{}),
|
||||
zeroRTTParametersChan: zeroRTTParametersChan,
|
||||
messageChan: make(chan []byte, 1),
|
||||
isReadingHandshakeMessage: make(chan struct{}),
|
||||
closeChan: make(chan struct{}),
|
||||
version: version,
|
||||
}
|
||||
var maxEarlyData uint32
|
||||
if enable0RTT {
|
||||
maxEarlyData = math.MaxUint32
|
||||
}
|
||||
cs.extraConf = &qtls.ExtraConfig{
|
||||
GetExtensions: extHandler.GetExtensions,
|
||||
ReceivedExtensions: extHandler.ReceivedExtensions,
|
||||
AlternativeRecordLayer: cs,
|
||||
EnforceNextProtoSelection: true,
|
||||
MaxEarlyData: maxEarlyData,
|
||||
Accept0RTT: cs.accept0RTT,
|
||||
Rejected0RTT: cs.rejected0RTT,
|
||||
Enable0RTT: enable0RTT,
|
||||
GetAppDataForSessionState: cs.marshalDataForSessionState,
|
||||
SetAppDataFromSessionState: cs.handleDataFromSessionState,
|
||||
}
|
||||
return cs, zeroRTTParametersChan
|
||||
}, zeroRTTParametersChan
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) {
|
||||
|
@ -301,142 +198,100 @@ func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) error {
|
|||
return h.aead.SetLargestAcked(pn)
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) RunHandshake() {
|
||||
// Handle errors that might occur when HandleData() is called.
|
||||
handshakeComplete := make(chan struct{})
|
||||
handshakeErrChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(h.handshakeDone)
|
||||
if err := h.conn.HandshakeContext(context.WithValue(context.Background(), QUICVersionContextKey, h.version)); err != nil {
|
||||
handshakeErrChan <- err
|
||||
return
|
||||
func (h *cryptoSetup) StartHandshake() error {
|
||||
err := h.conn.Start(context.WithValue(context.Background(), QUICVersionContextKey, h.version))
|
||||
if err != nil {
|
||||
return wrapError(err)
|
||||
}
|
||||
for {
|
||||
ev := h.conn.NextEvent()
|
||||
done, err := h.handleEvent(ev)
|
||||
if err != nil {
|
||||
return wrapError(err)
|
||||
}
|
||||
if done {
|
||||
break
|
||||
}
|
||||
}
|
||||
close(handshakeComplete)
|
||||
}()
|
||||
|
||||
if h.perspective == protocol.PerspectiveClient {
|
||||
select {
|
||||
case err := <-handshakeErrChan:
|
||||
h.onError(0, err.Error())
|
||||
return
|
||||
case <-h.clientHelloWrittenChan:
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-handshakeComplete: // return when the handshake is done
|
||||
h.mutex.Lock()
|
||||
h.handshakeCompleteTime = time.Now()
|
||||
h.mutex.Unlock()
|
||||
h.runner.OnHandshakeComplete()
|
||||
case <-h.closeChan:
|
||||
// wait until the Handshake() go routine has returned
|
||||
<-h.handshakeDone
|
||||
case alert := <-h.alertChan:
|
||||
handshakeErr := <-handshakeErrChan
|
||||
h.onError(alert, handshakeErr.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) onError(alert uint8, message string) {
|
||||
var err error
|
||||
if alert == 0 {
|
||||
err = &qerr.TransportError{ErrorCode: qerr.InternalError, ErrorMessage: message}
|
||||
if h.zeroRTTSealer != nil && h.zeroRTTParameters != nil {
|
||||
h.logger.Debugf("Doing 0-RTT.")
|
||||
h.zeroRTTParametersChan <- h.zeroRTTParameters
|
||||
} else {
|
||||
err = qerr.NewLocalCryptoError(alert, message)
|
||||
h.logger.Debugf("Not doing 0-RTT. Has sealer: %t, has params: %t", h.zeroRTTSealer != nil, h.zeroRTTParameters != nil)
|
||||
h.zeroRTTParametersChan <- nil
|
||||
}
|
||||
h.runner.OnError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the crypto setup.
|
||||
// It aborts the handshake, if it is still running.
|
||||
// It must only be called once.
|
||||
func (h *cryptoSetup) Close() error {
|
||||
close(h.closeChan)
|
||||
// wait until qtls.Handshake() actually returned
|
||||
<-h.handshakeDone
|
||||
return nil
|
||||
return h.conn.Close()
|
||||
}
|
||||
|
||||
// handleMessage handles a TLS handshake message.
|
||||
// HandleMessage handles a TLS handshake message.
|
||||
// It is called by the crypto streams when a new message is available.
|
||||
// It returns if it is done with messages on the same encryption level.
|
||||
func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) bool /* stream finished */ {
|
||||
msgType := messageType(data[0])
|
||||
h.logger.Debugf("Received %s message (%d bytes, encryption level: %s)", msgType, len(data), encLevel)
|
||||
if err := h.checkEncryptionLevel(msgType, encLevel); err != nil {
|
||||
h.onError(alertUnexpectedMessage, err.Error())
|
||||
return false
|
||||
}
|
||||
if encLevel != protocol.Encryption1RTT {
|
||||
select {
|
||||
case h.messageChan <- data:
|
||||
case <-h.handshakeDone: // handshake errored, nobody is going to consume this message
|
||||
return false
|
||||
}
|
||||
}
|
||||
if encLevel == protocol.Encryption1RTT {
|
||||
h.messageChan <- data
|
||||
h.handlePostHandshakeMessage()
|
||||
return false
|
||||
}
|
||||
readLoop:
|
||||
for {
|
||||
select {
|
||||
case data := <-h.paramsChan:
|
||||
if data == nil {
|
||||
h.onError(0x6d, "missing quic_transport_parameters extension")
|
||||
} else {
|
||||
h.handleTransportParameters(data)
|
||||
}
|
||||
case <-h.isReadingHandshakeMessage:
|
||||
break readLoop
|
||||
case <-h.handshakeDone:
|
||||
break readLoop
|
||||
case <-h.closeChan:
|
||||
break readLoop
|
||||
}
|
||||
}
|
||||
// We're done with the Initial encryption level after processing a ClientHello / ServerHello,
|
||||
// but only if a handshake opener and sealer was created.
|
||||
// Otherwise, a HelloRetryRequest was performed.
|
||||
// We're done with the Handshake encryption level after processing the Finished message.
|
||||
return ((msgType == typeClientHello || msgType == typeServerHello) && h.handshakeOpener != nil && h.handshakeSealer != nil) ||
|
||||
msgType == typeFinished
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error {
|
||||
var expected protocol.EncryptionLevel
|
||||
switch msgType {
|
||||
case typeClientHello, typeServerHello:
|
||||
expected = protocol.EncryptionInitial
|
||||
case typeEncryptedExtensions,
|
||||
typeCertificate,
|
||||
typeCertificateRequest,
|
||||
typeCertificateVerify,
|
||||
typeFinished:
|
||||
expected = protocol.EncryptionHandshake
|
||||
case typeNewSessionTicket:
|
||||
expected = protocol.Encryption1RTT
|
||||
default:
|
||||
return fmt.Errorf("unexpected handshake message: %d", msgType)
|
||||
}
|
||||
if encLevel != expected {
|
||||
return fmt.Errorf("expected handshake message %s to have encryption level %s, has %s", msgType, expected, encLevel)
|
||||
func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) error {
|
||||
if err := h.handleMessage(data, encLevel); err != nil {
|
||||
return wrapError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) handleTransportParameters(data []byte) {
|
||||
func (h *cryptoSetup) handleMessage(data []byte, encLevel protocol.EncryptionLevel) error {
|
||||
if err := h.conn.HandleData(qtls.ToTLSEncryptionLevel(encLevel), data); err != nil {
|
||||
return err
|
||||
}
|
||||
for {
|
||||
ev := h.conn.NextEvent()
|
||||
done, err := h.handleEvent(ev)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if done {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) handleEvent(ev qtls.QUICEvent) (done bool, err error) {
|
||||
switch ev.Kind {
|
||||
case qtls.QUICNoEvent:
|
||||
return true, nil
|
||||
case qtls.QUICSetReadSecret:
|
||||
h.SetReadKey(ev.Level, ev.Suite, ev.Data)
|
||||
return false, nil
|
||||
case qtls.QUICSetWriteSecret:
|
||||
h.SetWriteKey(ev.Level, ev.Suite, ev.Data)
|
||||
return false, nil
|
||||
case qtls.QUICTransportParameters:
|
||||
return false, h.handleTransportParameters(ev.Data)
|
||||
case qtls.QUICTransportParametersRequired:
|
||||
h.conn.SetTransportParameters(h.ourParams.Marshal(h.perspective))
|
||||
return false, nil
|
||||
case qtls.QUICRejectedEarlyData:
|
||||
h.rejected0RTT()
|
||||
return false, nil
|
||||
case qtls.QUICWriteData:
|
||||
return false, h.WriteRecord(ev.Level, ev.Data)
|
||||
case qtls.QUICHandshakeDone:
|
||||
h.handshakeComplete()
|
||||
return false, nil
|
||||
default:
|
||||
return false, fmt.Errorf("unexpected event: %d", ev.Kind)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) handleTransportParameters(data []byte) error {
|
||||
var tp wire.TransportParameters
|
||||
if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil {
|
||||
h.runner.OnError(&qerr.TransportError{
|
||||
ErrorCode: qerr.TransportParameterError,
|
||||
ErrorMessage: err.Error(),
|
||||
})
|
||||
return err
|
||||
}
|
||||
h.peerParams = &tp
|
||||
h.runner.OnReceivedParams(h.peerParams)
|
||||
return nil
|
||||
}
|
||||
|
||||
// must be called after receiving the transport parameters
|
||||
|
@ -477,17 +332,32 @@ func (h *cryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.Transpo
|
|||
return &tp, nil
|
||||
}
|
||||
|
||||
// only valid for the server
|
||||
func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
|
||||
var appData []byte
|
||||
// Save transport parameters to the session ticket if we're allowing 0-RTT.
|
||||
if h.extraConf.MaxEarlyData > 0 {
|
||||
appData = (&sessionTicket{
|
||||
func (h *cryptoSetup) getDataForSessionTicket() []byte {
|
||||
return (&sessionTicket{
|
||||
Parameters: h.ourParams,
|
||||
RTT: h.rttStats.SmoothedRTT(),
|
||||
}).Marshal()
|
||||
}
|
||||
return h.conn.GetSessionTicket(appData)
|
||||
|
||||
// GetSessionTicket generates a new session ticket.
|
||||
// Due to limitations in crypto/tls, it's only possible to generate a single session ticket per connection.
|
||||
// It is only valid for the server.
|
||||
func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
|
||||
if h.tlsConf.SessionTicketsDisabled {
|
||||
return nil, nil
|
||||
}
|
||||
if err := h.conn.SendSessionTicket(h.allow0RTT); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ev := h.conn.NextEvent()
|
||||
if ev.Kind != qtls.QUICWriteData || ev.Level != qtls.QUICEncryptionLevelApplication {
|
||||
panic("crypto/tls bug: where's my session ticket?")
|
||||
}
|
||||
ticket := ev.Data
|
||||
if ev := h.conn.NextEvent(); ev.Kind != qtls.QUICNoEvent {
|
||||
panic("crypto/tls bug: why more than one ticket?")
|
||||
}
|
||||
return ticket, nil
|
||||
}
|
||||
|
||||
// accept0RTT is called for the server when receiving the client's session ticket.
|
||||
|
@ -526,60 +396,12 @@ func (h *cryptoSetup) rejected0RTT() {
|
|||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) handlePostHandshakeMessage() {
|
||||
// make sure the handshake has already completed
|
||||
<-h.handshakeDone
|
||||
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
|
||||
// h.alertChan is an unbuffered channel.
|
||||
// If an error occurs during conn.HandlePostHandshakeMessage,
|
||||
// it will be sent on this channel.
|
||||
// Read it from a go-routine so that HandlePostHandshakeMessage doesn't deadlock.
|
||||
alertChan := make(chan uint8, 1)
|
||||
go func() {
|
||||
<-h.isReadingHandshakeMessage
|
||||
select {
|
||||
case alert := <-h.alertChan:
|
||||
alertChan <- alert
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
|
||||
if err := h.conn.HandlePostHandshakeMessage(); err != nil {
|
||||
select {
|
||||
case <-h.closeChan:
|
||||
case alert := <-alertChan:
|
||||
h.onError(alert, err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ReadHandshakeMessage is called by TLS.
|
||||
// It blocks until a new handshake message is available.
|
||||
func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) {
|
||||
if !h.readFirstHandshakeMessage {
|
||||
h.readFirstHandshakeMessage = true
|
||||
} else {
|
||||
select {
|
||||
case h.isReadingHandshakeMessage <- struct{}{}:
|
||||
case <-h.closeChan:
|
||||
return nil, errors.New("error while handling the handshake message")
|
||||
}
|
||||
}
|
||||
select {
|
||||
case msg := <-h.messageChan:
|
||||
return msg, nil
|
||||
case <-h.closeChan:
|
||||
return nil, errors.New("error while handling the handshake message")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
|
||||
func (h *cryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
|
||||
suite := getCipherSuite(suiteID)
|
||||
h.mutex.Lock()
|
||||
switch encLevel {
|
||||
case qtls.Encryption0RTT:
|
||||
//nolint:exhaustive // The TLS stack doesn't export Initial keys.
|
||||
switch el {
|
||||
case qtls.QUICEncryptionLevelEarly:
|
||||
if h.perspective == protocol.PerspectiveClient {
|
||||
panic("Received 0-RTT read key for the client")
|
||||
}
|
||||
|
@ -587,16 +409,11 @@ func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.Ciph
|
|||
createAEAD(suite, trafficSecret, h.version),
|
||||
newHeaderProtector(suite, trafficSecret, true, h.version),
|
||||
)
|
||||
h.mutex.Unlock()
|
||||
h.used0RTT.Store(true)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Installed 0-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID))
|
||||
}
|
||||
if h.tracer != nil {
|
||||
h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective.Opposite())
|
||||
}
|
||||
return
|
||||
case qtls.EncryptionHandshake:
|
||||
h.readEncLevel = protocol.EncryptionHandshake
|
||||
case qtls.QUICEncryptionLevelHandshake:
|
||||
h.handshakeOpener = newHandshakeOpener(
|
||||
createAEAD(suite, trafficSecret, h.version),
|
||||
newHeaderProtector(suite, trafficSecret, true, h.version),
|
||||
|
@ -606,8 +423,7 @@ func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.Ciph
|
|||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Installed Handshake Read keys (using %s)", tls.CipherSuiteName(suite.ID))
|
||||
}
|
||||
case qtls.EncryptionApplication:
|
||||
h.readEncLevel = protocol.Encryption1RTT
|
||||
case qtls.QUICEncryptionLevelApplication:
|
||||
h.aead.SetReadKey(suite, trafficSecret)
|
||||
h.has1RTTOpener = true
|
||||
if h.logger.Debug() {
|
||||
|
@ -617,15 +433,18 @@ func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.Ciph
|
|||
panic("unexpected read encryption level")
|
||||
}
|
||||
h.mutex.Unlock()
|
||||
h.runner.OnReceivedReadKeys()
|
||||
if h.tracer != nil {
|
||||
h.tracer.UpdatedKeyFromTLS(h.readEncLevel, h.perspective.Opposite())
|
||||
h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective.Opposite())
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
|
||||
func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
|
||||
suite := getCipherSuite(suiteID)
|
||||
h.mutex.Lock()
|
||||
switch encLevel {
|
||||
case qtls.Encryption0RTT:
|
||||
//nolint:exhaustive // The TLS stack doesn't export Initial keys.
|
||||
switch el {
|
||||
case qtls.QUICEncryptionLevelEarly:
|
||||
if h.perspective == protocol.PerspectiveServer {
|
||||
panic("Received 0-RTT write key for the server")
|
||||
}
|
||||
|
@ -640,9 +459,9 @@ func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.Cip
|
|||
if h.tracer != nil {
|
||||
h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective)
|
||||
}
|
||||
// don't set used0RTT here. 0-RTT might still get rejected.
|
||||
return
|
||||
case qtls.EncryptionHandshake:
|
||||
h.writeEncLevel = protocol.EncryptionHandshake
|
||||
case qtls.QUICEncryptionLevelHandshake:
|
||||
h.handshakeSealer = newHandshakeSealer(
|
||||
createAEAD(suite, trafficSecret, h.version),
|
||||
newHeaderProtector(suite, trafficSecret, true, h.version),
|
||||
|
@ -652,14 +471,15 @@ func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.Cip
|
|||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Installed Handshake Write keys (using %s)", tls.CipherSuiteName(suite.ID))
|
||||
}
|
||||
case qtls.EncryptionApplication:
|
||||
h.writeEncLevel = protocol.Encryption1RTT
|
||||
case qtls.QUICEncryptionLevelApplication:
|
||||
h.aead.SetWriteKey(suite, trafficSecret)
|
||||
h.has1RTTSealer = true
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Installed 1-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID))
|
||||
}
|
||||
if h.zeroRTTSealer != nil {
|
||||
// Once we receive handshake keys, we know that 0-RTT was not rejected.
|
||||
h.used0RTT.Store(true)
|
||||
h.zeroRTTSealer = nil
|
||||
h.logger.Debugf("Dropping 0-RTT keys.")
|
||||
if h.tracer != nil {
|
||||
|
@ -671,45 +491,30 @@ func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.Cip
|
|||
}
|
||||
h.mutex.Unlock()
|
||||
if h.tracer != nil {
|
||||
h.tracer.UpdatedKeyFromTLS(h.writeEncLevel, h.perspective)
|
||||
h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective)
|
||||
}
|
||||
}
|
||||
|
||||
// WriteRecord is called when TLS writes data
|
||||
func (h *cryptoSetup) WriteRecord(p []byte) (int, error) {
|
||||
func (h *cryptoSetup) WriteRecord(encLevel qtls.QUICEncryptionLevel, p []byte) error {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
//nolint:exhaustive // LS records can only be written for Initial and Handshake.
|
||||
switch h.writeEncLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
var str io.Writer
|
||||
//nolint:exhaustive // handshake records can only be written for Initial and Handshake.
|
||||
switch encLevel {
|
||||
case qtls.QUICEncryptionLevelInitial:
|
||||
// assume that the first WriteRecord call contains the ClientHello
|
||||
n, err := h.initialStream.Write(p)
|
||||
if !h.clientHelloWritten && h.perspective == protocol.PerspectiveClient {
|
||||
h.clientHelloWritten = true
|
||||
close(h.clientHelloWrittenChan)
|
||||
if h.zeroRTTSealer != nil && h.zeroRTTParameters != nil {
|
||||
h.logger.Debugf("Doing 0-RTT.")
|
||||
h.zeroRTTParametersChan <- h.zeroRTTParameters
|
||||
} else {
|
||||
h.logger.Debugf("Not doing 0-RTT.")
|
||||
h.zeroRTTParametersChan <- nil
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
case protocol.EncryptionHandshake:
|
||||
return h.handshakeStream.Write(p)
|
||||
str = h.initialStream
|
||||
case qtls.QUICEncryptionLevelHandshake:
|
||||
str = h.handshakeStream
|
||||
case qtls.QUICEncryptionLevelApplication:
|
||||
str = h.oneRTTStream
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected write encryption level: %s", h.writeEncLevel))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) SendAlert(alert uint8) {
|
||||
select {
|
||||
case h.alertChan <- alert:
|
||||
case <-h.closeChan:
|
||||
// no need to send an alert when we've already closed
|
||||
panic(fmt.Sprintf("unexpected write encryption level: %s", encLevel))
|
||||
}
|
||||
_, err := str.Write(p)
|
||||
return err
|
||||
}
|
||||
|
||||
// used a callback in the handshakeSealer and handshakeOpener
|
||||
|
@ -722,6 +527,11 @@ func (h *cryptoSetup) dropInitialKeys() {
|
|||
h.logger.Debugf("Dropping Initial keys.")
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) handshakeComplete() {
|
||||
h.handshakeCompleteTime = time.Now()
|
||||
h.runner.OnHandshakeComplete()
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) SetHandshakeConfirmed() {
|
||||
h.aead.SetHandshakeConfirmed()
|
||||
// drop Handshake keys
|
||||
|
@ -839,5 +649,15 @@ func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) {
|
|||
}
|
||||
|
||||
func (h *cryptoSetup) ConnectionState() ConnectionState {
|
||||
return qtls.GetConnectionState(h.conn)
|
||||
return ConnectionState{
|
||||
ConnectionState: h.conn.ConnectionState(),
|
||||
Used0RTT: h.used0RTT.Load(),
|
||||
}
|
||||
}
|
||||
|
||||
func wrapError(err error) error {
|
||||
if alertErr := qtls.AlertError(0); errors.As(err, &alertErr) && alertErr != 80 {
|
||||
return qerr.NewLocalCryptoError(uint8(alertErr), err.Error())
|
||||
}
|
||||
return &qerr.TransportError{ErrorCode: qerr.InternalError, ErrorMessage: err.Error()}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
|
@ -23,12 +22,10 @@ import (
|
|||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var helloRetryRequestRandom = []byte{ // See RFC 8446, Section 4.1.3.
|
||||
0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11,
|
||||
0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91,
|
||||
0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E,
|
||||
0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C,
|
||||
}
|
||||
const (
|
||||
typeClientHello = 1
|
||||
typeNewSessionTicket = 4
|
||||
)
|
||||
|
||||
type chunk struct {
|
||||
data []byte
|
||||
|
@ -80,54 +77,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
}
|
||||
})
|
||||
|
||||
It("returns Handshake() when an error occurs in qtls", func() {
|
||||
sErrChan := make(chan error, 1)
|
||||
runner := NewMockHandshakeRunner(mockCtrl)
|
||||
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
|
||||
_, sInitialStream, sHandshakeStream := initStreams()
|
||||
var token protocol.StatelessResetToken
|
||||
server := NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
nil,
|
||||
&wire.TransportParameters{StatelessResetToken: &token},
|
||||
runner,
|
||||
testdata.GetTLSConfig(),
|
||||
false,
|
||||
&utils.RTTStats{},
|
||||
nil,
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
protocol.Version1,
|
||||
)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
server.RunHandshake()
|
||||
Expect(sErrChan).To(Receive(MatchError(&qerr.TransportError{
|
||||
ErrorCode: 0x100 + qerr.TransportErrorCode(alertUnexpectedMessage),
|
||||
ErrorMessage: "local error: tls: unexpected message",
|
||||
})))
|
||||
close(done)
|
||||
}()
|
||||
|
||||
fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...)
|
||||
handledMessage := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
server.HandleMessage(fakeCH, protocol.EncryptionInitial)
|
||||
close(handledMessage)
|
||||
}()
|
||||
Eventually(handledMessage).Should(BeClosed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("handles qtls errors occurring before during ClientHello generation", func() {
|
||||
sErrChan := make(chan error, 1)
|
||||
runner := NewMockHandshakeRunner(mockCtrl)
|
||||
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
|
||||
_, sInitialStream, sHandshakeStream := initStreams()
|
||||
tlsConf := testdata.GetTLSConfig()
|
||||
tlsConf.InsecureSkipVerify = true
|
||||
|
@ -135,11 +85,10 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
cl, _ := NewCryptoSetupClient(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
nil,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
nil,
|
||||
&wire.TransportParameters{},
|
||||
runner,
|
||||
NewMockHandshakeRunner(mockCtrl),
|
||||
tlsConf,
|
||||
false,
|
||||
&utils.RTTStats{},
|
||||
|
@ -148,32 +97,21 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
protocol.Version1,
|
||||
)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cl.RunHandshake()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
Eventually(done).Should(BeClosed())
|
||||
Expect(sErrChan).To(Receive(MatchError(&qerr.TransportError{
|
||||
Expect(cl.StartHandshake()).To(MatchError(&qerr.TransportError{
|
||||
ErrorCode: qerr.InternalError,
|
||||
ErrorMessage: "tls: invalid NextProtos value",
|
||||
})))
|
||||
}))
|
||||
})
|
||||
|
||||
It("errors when a message is received at the wrong encryption level", func() {
|
||||
sErrChan := make(chan error, 1)
|
||||
_, sInitialStream, sHandshakeStream := initStreams()
|
||||
runner := NewMockHandshakeRunner(mockCtrl)
|
||||
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
|
||||
var token protocol.StatelessResetToken
|
||||
server := NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
nil,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
nil,
|
||||
&wire.TransportParameters{StatelessResetToken: &token},
|
||||
runner,
|
||||
testdata.GetTLSConfig(),
|
||||
|
@ -184,90 +122,13 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
protocol.Version1,
|
||||
)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
server.RunHandshake()
|
||||
close(done)
|
||||
}()
|
||||
Expect(server.StartHandshake()).To(Succeed())
|
||||
|
||||
fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...)
|
||||
server.HandleMessage(fakeCH, protocol.EncryptionHandshake) // wrong encryption level
|
||||
Expect(sErrChan).To(Receive(MatchError(&qerr.TransportError{
|
||||
ErrorCode: 0x100 + qerr.TransportErrorCode(alertUnexpectedMessage),
|
||||
ErrorMessage: "expected handshake message ClientHello to have encryption level Initial, has Handshake",
|
||||
})))
|
||||
|
||||
// make the go routine return
|
||||
Expect(server.Close()).To(Succeed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("returns Handshake() when handling a message fails", func() {
|
||||
sErrChan := make(chan error, 1)
|
||||
_, sInitialStream, sHandshakeStream := initStreams()
|
||||
runner := NewMockHandshakeRunner(mockCtrl)
|
||||
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
|
||||
var token protocol.StatelessResetToken
|
||||
server := NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
nil,
|
||||
&wire.TransportParameters{StatelessResetToken: &token},
|
||||
runner,
|
||||
serverConf,
|
||||
false,
|
||||
&utils.RTTStats{},
|
||||
nil,
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
protocol.Version1,
|
||||
)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
server.RunHandshake()
|
||||
var err error
|
||||
Expect(sErrChan).To(Receive(&err))
|
||||
Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{}))
|
||||
Expect(err.(*qerr.TransportError).ErrorCode).To(BeEquivalentTo(0x100 + int(alertUnexpectedMessage)))
|
||||
close(done)
|
||||
}()
|
||||
|
||||
fakeCH := append([]byte{byte(typeServerHello), 0, 0, 6}, []byte("foobar")...)
|
||||
server.HandleMessage(fakeCH, protocol.EncryptionInitial) // wrong encryption level
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("returns Handshake() when it is closed", func() {
|
||||
_, sInitialStream, sHandshakeStream := initStreams()
|
||||
var token protocol.StatelessResetToken
|
||||
server := NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
nil,
|
||||
&wire.TransportParameters{StatelessResetToken: &token},
|
||||
NewMockHandshakeRunner(mockCtrl),
|
||||
serverConf,
|
||||
false,
|
||||
&utils.RTTStats{},
|
||||
nil,
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
protocol.Version1,
|
||||
)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
server.RunHandshake()
|
||||
close(done)
|
||||
}()
|
||||
Expect(server.Close()).To(Succeed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
fakeCH := append([]byte{typeClientHello, 0, 0, 6}, []byte("foobar")...)
|
||||
// wrong encryption level
|
||||
err := server.HandleMessage(fakeCH, protocol.EncryptionHandshake)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level"))
|
||||
})
|
||||
|
||||
Context("doing the handshake", func() {
|
||||
|
@ -297,55 +158,32 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
return rttStats
|
||||
}
|
||||
|
||||
handshake := func(client CryptoSetup, cChunkChan <-chan chunk,
|
||||
server CryptoSetup, sChunkChan <-chan chunk,
|
||||
) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
handshake := func(client CryptoSetup, cChunkChan <-chan chunk, server CryptoSetup, sChunkChan <-chan chunk) {
|
||||
Expect(client.StartHandshake()).To(Succeed())
|
||||
Expect(server.StartHandshake()).To(Succeed())
|
||||
|
||||
for {
|
||||
select {
|
||||
case c := <-cChunkChan:
|
||||
msgType := messageType(c.data[0])
|
||||
finished := server.HandleMessage(c.data, c.encLevel)
|
||||
if msgType == typeFinished {
|
||||
Expect(finished).To(BeTrue())
|
||||
} else if msgType == typeClientHello {
|
||||
// If this ClientHello didn't elicit a HelloRetryRequest, we're done with Initial keys.
|
||||
_, err := server.GetHandshakeOpener()
|
||||
Expect(finished).To(Equal(err == nil))
|
||||
} else {
|
||||
Expect(finished).To(BeFalse())
|
||||
Expect(server.HandleMessage(c.data, c.encLevel)).To(Succeed())
|
||||
continue
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case c := <-sChunkChan:
|
||||
msgType := messageType(c.data[0])
|
||||
finished := client.HandleMessage(c.data, c.encLevel)
|
||||
if msgType == typeFinished {
|
||||
Expect(finished).To(BeTrue())
|
||||
} else if msgType == typeServerHello {
|
||||
Expect(finished).To(Equal(!bytes.Equal(c.data[6:6+32], helloRetryRequestRandom)))
|
||||
} else {
|
||||
Expect(finished).To(BeFalse())
|
||||
Expect(client.HandleMessage(c.data, c.encLevel)).To(Succeed())
|
||||
continue
|
||||
default:
|
||||
}
|
||||
case <-done: // handshake complete
|
||||
return
|
||||
// no more messages to send from client and server. Handshake complete?
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(done)
|
||||
server.RunHandshake()
|
||||
ticket, err := server.GetSessionTicket()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
if ticket != nil {
|
||||
client.HandleMessage(ticket, protocol.Encryption1RTT)
|
||||
Expect(client.HandleMessage(ticket, protocol.Encryption1RTT)).To(Succeed())
|
||||
}
|
||||
}()
|
||||
|
||||
client.RunHandshake()
|
||||
Eventually(done).Should(BeClosed())
|
||||
}
|
||||
|
||||
handshakeWithTLSConf := func(
|
||||
|
@ -359,15 +197,14 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
cErrChan := make(chan error, 1)
|
||||
cRunner := NewMockHandshakeRunner(mockCtrl)
|
||||
cRunner.EXPECT().OnReceivedParams(gomock.Any())
|
||||
cRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { cErrChan <- e }).MaxTimes(1)
|
||||
cRunner.EXPECT().OnReceivedReadKeys().MinTimes(2).MaxTimes(3) // 3 if using 0-RTT, 2 otherwise
|
||||
cRunner.EXPECT().OnHandshakeComplete().Do(func() { cHandshakeComplete = true }).MaxTimes(1)
|
||||
cRunner.EXPECT().DropKeys(gomock.Any()).MaxTimes(1)
|
||||
client, clientHelloWrittenChan := NewCryptoSetupClient(
|
||||
cInitialStream,
|
||||
cHandshakeStream,
|
||||
nil,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
nil,
|
||||
clientTransportParameters,
|
||||
cRunner,
|
||||
clientConf,
|
||||
|
@ -383,7 +220,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
sErrChan := make(chan error, 1)
|
||||
sRunner := NewMockHandshakeRunner(mockCtrl)
|
||||
sRunner.EXPECT().OnReceivedParams(gomock.Any())
|
||||
sRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }).MaxTimes(1)
|
||||
sRunner.EXPECT().OnReceivedReadKeys().MinTimes(2).MaxTimes(3) // 3 if using 0-RTT, 2 otherwise
|
||||
sRunner.EXPECT().OnHandshakeComplete().Do(func() { sHandshakeComplete = true }).MaxTimes(1)
|
||||
if serverTransportParameters.StatelessResetToken == nil {
|
||||
var token protocol.StatelessResetToken
|
||||
|
@ -392,9 +229,8 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
server := NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
nil,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
nil,
|
||||
serverTransportParameters,
|
||||
sRunner,
|
||||
serverConf,
|
||||
|
@ -462,9 +298,8 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
client, chChan := NewCryptoSetupClient(
|
||||
cInitialStream,
|
||||
cHandshakeStream,
|
||||
nil,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
nil,
|
||||
&wire.TransportParameters{},
|
||||
runner,
|
||||
&tls.Config{InsecureSkipVerify: true},
|
||||
|
@ -475,24 +310,15 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
protocol.Version1,
|
||||
)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
client.RunHandshake()
|
||||
close(done)
|
||||
}()
|
||||
Expect(client.StartHandshake()).To(Succeed())
|
||||
var ch chunk
|
||||
Eventually(cChunkChan).Should(Receive(&ch))
|
||||
Eventually(chChan).Should(Receive(BeNil()))
|
||||
// make sure the whole ClientHello was written
|
||||
Expect(len(ch.data)).To(BeNumerically(">=", 4))
|
||||
Expect(messageType(ch.data[0])).To(Equal(typeClientHello))
|
||||
Expect(ch.data[0]).To(BeEquivalentTo(typeClientHello))
|
||||
length := int(ch.data[1])<<16 | int(ch.data[2])<<8 | int(ch.data[3])
|
||||
Expect(len(ch.data) - 4).To(Equal(length))
|
||||
|
||||
// make the go routine return
|
||||
Expect(client.Close()).To(Succeed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("receives transport parameters", func() {
|
||||
|
@ -500,14 +326,14 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||
cTransportParameters := &wire.TransportParameters{ActiveConnectionIDLimit: 2, MaxIdleTimeout: 0x42 * time.Second}
|
||||
cRunner := NewMockHandshakeRunner(mockCtrl)
|
||||
cRunner.EXPECT().OnReceivedReadKeys().Times(2)
|
||||
cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *wire.TransportParameters) { sTransportParametersRcvd = tp })
|
||||
cRunner.EXPECT().OnHandshakeComplete()
|
||||
client, _ := NewCryptoSetupClient(
|
||||
cInitialStream,
|
||||
cHandshakeStream,
|
||||
nil,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
nil,
|
||||
cTransportParameters,
|
||||
cRunner,
|
||||
clientConf,
|
||||
|
@ -521,6 +347,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
||||
var token protocol.StatelessResetToken
|
||||
sRunner := NewMockHandshakeRunner(mockCtrl)
|
||||
sRunner.EXPECT().OnReceivedReadKeys().Times(2)
|
||||
sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *wire.TransportParameters) { cTransportParametersRcvd = tp })
|
||||
sRunner.EXPECT().OnHandshakeComplete()
|
||||
sTransportParameters := &wire.TransportParameters{
|
||||
|
@ -531,9 +358,8 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
server := NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
nil,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
nil,
|
||||
sTransportParameters,
|
||||
sRunner,
|
||||
serverConf,
|
||||
|
@ -561,13 +387,13 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||
cRunner := NewMockHandshakeRunner(mockCtrl)
|
||||
cRunner.EXPECT().OnReceivedParams(gomock.Any())
|
||||
cRunner.EXPECT().OnReceivedReadKeys().Times(2)
|
||||
cRunner.EXPECT().OnHandshakeComplete()
|
||||
client, _ := NewCryptoSetupClient(
|
||||
cInitialStream,
|
||||
cHandshakeStream,
|
||||
nil,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
nil,
|
||||
&wire.TransportParameters{ActiveConnectionIDLimit: 2},
|
||||
cRunner,
|
||||
clientConf,
|
||||
|
@ -581,14 +407,14 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
||||
sRunner := NewMockHandshakeRunner(mockCtrl)
|
||||
sRunner.EXPECT().OnReceivedParams(gomock.Any())
|
||||
sRunner.EXPECT().OnReceivedReadKeys().Times(2)
|
||||
sRunner.EXPECT().OnHandshakeComplete()
|
||||
var token protocol.StatelessResetToken
|
||||
server := NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
nil,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
nil,
|
||||
&wire.TransportParameters{ActiveConnectionIDLimit: 2, StatelessResetToken: &token},
|
||||
sRunner,
|
||||
serverConf,
|
||||
|
@ -608,25 +434,23 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
Eventually(done).Should(BeClosed())
|
||||
|
||||
// inject an invalid session ticket
|
||||
cRunner.EXPECT().OnError(&qerr.TransportError{
|
||||
ErrorCode: 0x100 + qerr.TransportErrorCode(alertUnexpectedMessage),
|
||||
ErrorMessage: "expected handshake message NewSessionTicket to have encryption level 1-RTT, has Handshake",
|
||||
})
|
||||
b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...)
|
||||
client.HandleMessage(b, protocol.EncryptionHandshake)
|
||||
err := client.HandleMessage(b, protocol.EncryptionHandshake)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level"))
|
||||
})
|
||||
|
||||
It("errors when handling the NewSessionTicket fails", func() {
|
||||
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||
cRunner := NewMockHandshakeRunner(mockCtrl)
|
||||
cRunner.EXPECT().OnReceivedParams(gomock.Any())
|
||||
cRunner.EXPECT().OnReceivedReadKeys().Times(2)
|
||||
cRunner.EXPECT().OnHandshakeComplete()
|
||||
client, _ := NewCryptoSetupClient(
|
||||
cInitialStream,
|
||||
cHandshakeStream,
|
||||
nil,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
nil,
|
||||
&wire.TransportParameters{ActiveConnectionIDLimit: 2},
|
||||
cRunner,
|
||||
clientConf,
|
||||
|
@ -640,14 +464,14 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
||||
sRunner := NewMockHandshakeRunner(mockCtrl)
|
||||
sRunner.EXPECT().OnReceivedParams(gomock.Any())
|
||||
sRunner.EXPECT().OnReceivedReadKeys().Times(2)
|
||||
sRunner.EXPECT().OnHandshakeComplete()
|
||||
var token protocol.StatelessResetToken
|
||||
server := NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
nil,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
nil,
|
||||
&wire.TransportParameters{ActiveConnectionIDLimit: 2, StatelessResetToken: &token},
|
||||
sRunner,
|
||||
serverConf,
|
||||
|
@ -667,13 +491,11 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
Eventually(done).Should(BeClosed())
|
||||
|
||||
// inject an invalid session ticket
|
||||
cRunner.EXPECT().OnError(gomock.Any()).Do(func(err error) {
|
||||
b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...)
|
||||
err := client.HandleMessage(b, protocol.Encryption1RTT)
|
||||
Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{}))
|
||||
Expect(err.(*qerr.TransportError).ErrorCode.IsCryptoError()).To(BeTrue())
|
||||
})
|
||||
b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...)
|
||||
client.HandleMessage(b, protocol.Encryption1RTT)
|
||||
})
|
||||
|
||||
It("uses session resumption", func() {
|
||||
csc := mocktls.NewMockClientSessionCache(mockCtrl)
|
||||
|
@ -785,7 +607,6 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
Expect(clientHelloWrittenChan).To(Receive(BeNil()))
|
||||
|
||||
csc.EXPECT().Get(gomock.Any()).Return(state, true)
|
||||
csc.EXPECT().Put(gomock.Any(), nil)
|
||||
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
|
||||
|
||||
clientRTTStats := &utils.RTTStats{}
|
||||
|
@ -840,7 +661,6 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
Expect(clientHelloWrittenChan).To(Receive(BeNil()))
|
||||
|
||||
csc.EXPECT().Get(gomock.Any()).Return(state, true)
|
||||
csc.EXPECT().Put(gomock.Any(), nil)
|
||||
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
|
||||
|
||||
clientRTTStats := &utils.RTTStats{}
|
||||
|
|
|
@ -6,8 +6,6 @@ import (
|
|||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/qtls"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
|
@ -41,8 +39,8 @@ func splitHexString(s string) (slice []byte) {
|
|||
return
|
||||
}
|
||||
|
||||
var cipherSuites = []*qtls.CipherSuiteTLS13{
|
||||
qtls.CipherSuiteTLS13ByID(tls.TLS_AES_128_GCM_SHA256),
|
||||
qtls.CipherSuiteTLS13ByID(tls.TLS_AES_256_GCM_SHA384),
|
||||
qtls.CipherSuiteTLS13ByID(tls.TLS_CHACHA20_POLY1305_SHA256),
|
||||
var cipherSuites = []*cipherSuite{
|
||||
getCipherSuite(tls.TLS_AES_128_GCM_SHA256),
|
||||
getCipherSuite(tls.TLS_AES_256_GCM_SHA384),
|
||||
getCipherSuite(tls.TLS_CHACHA20_POLY1305_SHA256),
|
||||
}
|
||||
|
|
|
@ -10,7 +10,6 @@ import (
|
|||
"golang.org/x/crypto/chacha20"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qtls"
|
||||
)
|
||||
|
||||
type headerProtector interface {
|
||||
|
@ -25,7 +24,7 @@ func hkdfHeaderProtectionLabel(v protocol.VersionNumber) string {
|
|||
return "quic hp"
|
||||
}
|
||||
|
||||
func newHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool, v protocol.VersionNumber) headerProtector {
|
||||
func newHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, v protocol.VersionNumber) headerProtector {
|
||||
hkdfLabel := hkdfHeaderProtectionLabel(v)
|
||||
switch suite.ID {
|
||||
case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
|
||||
|
@ -45,7 +44,7 @@ type aesHeaderProtector struct {
|
|||
|
||||
var _ headerProtector = &aesHeaderProtector{}
|
||||
|
||||
func newAESHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector {
|
||||
func newAESHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector {
|
||||
hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, hkdfLabel, suite.KeyLen)
|
||||
block, err := aes.NewCipher(hpKey)
|
||||
if err != nil {
|
||||
|
@ -90,7 +89,7 @@ type chachaHeaderProtector struct {
|
|||
|
||||
var _ headerProtector = &chachaHeaderProtector{}
|
||||
|
||||
func newChaChaHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector {
|
||||
func newChaChaHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector {
|
||||
hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, hkdfLabel, suite.KeyLen)
|
||||
|
||||
p := &chachaHeaderProtector{
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"golang.org/x/crypto/hkdf"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qtls"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -29,12 +28,7 @@ func getSalt(v protocol.VersionNumber) []byte {
|
|||
return quicSaltV1
|
||||
}
|
||||
|
||||
var initialSuite = &qtls.CipherSuiteTLS13{
|
||||
ID: tls.TLS_AES_128_GCM_SHA256,
|
||||
KeyLen: 16,
|
||||
AEAD: qtls.AEADAESGCMTLS13,
|
||||
Hash: crypto.SHA256,
|
||||
}
|
||||
var initialSuite = getCipherSuite(tls.TLS_AES_128_GCM_SHA256)
|
||||
|
||||
// NewInitialAEAD creates a new AEAD for Initial encryption / decryption.
|
||||
func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v protocol.VersionNumber) (LongHeaderSealer, LongHeaderOpener) {
|
||||
|
@ -50,8 +44,8 @@ func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v p
|
|||
myKey, myIV := computeInitialKeyAndIV(mySecret, v)
|
||||
otherKey, otherIV := computeInitialKeyAndIV(otherSecret, v)
|
||||
|
||||
encrypter := qtls.AEADAESGCMTLS13(myKey, myIV)
|
||||
decrypter := qtls.AEADAESGCMTLS13(otherKey, otherIV)
|
||||
encrypter := initialSuite.AEAD(myKey, myIV)
|
||||
decrypter := initialSuite.AEAD(otherKey, otherIV)
|
||||
|
||||
return newLongHeaderSealer(encrypter, newHeaderProtector(initialSuite, mySecret, true, v)),
|
||||
newLongHeaderOpener(decrypter, newAESHeaderProtector(initialSuite, otherSecret, true, hkdfHeaderProtectionLabel(v)))
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qtls"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
|
@ -22,9 +22,6 @@ var (
|
|||
ErrDecryptionFailed = errors.New("decryption failed")
|
||||
)
|
||||
|
||||
// ConnectionState contains information about the state of the connection.
|
||||
type ConnectionState = qtls.ConnectionState
|
||||
|
||||
type headerDecryptor interface {
|
||||
DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
|
||||
}
|
||||
|
@ -56,28 +53,26 @@ type ShortHeaderSealer interface {
|
|||
KeyPhase() protocol.KeyPhaseBit
|
||||
}
|
||||
|
||||
// A tlsExtensionHandler sends and received the QUIC TLS extension.
|
||||
type tlsExtensionHandler interface {
|
||||
GetExtensions(msgType uint8) []qtls.Extension
|
||||
ReceivedExtensions(msgType uint8, exts []qtls.Extension)
|
||||
TransportParameters() <-chan []byte
|
||||
}
|
||||
|
||||
type handshakeRunner interface {
|
||||
OnReceivedParams(*wire.TransportParameters)
|
||||
OnHandshakeComplete()
|
||||
OnError(error)
|
||||
OnReceivedReadKeys()
|
||||
DropKeys(protocol.EncryptionLevel)
|
||||
}
|
||||
|
||||
type ConnectionState struct {
|
||||
tls.ConnectionState
|
||||
Used0RTT bool
|
||||
}
|
||||
|
||||
// CryptoSetup handles the handshake and protecting / unprotecting packets
|
||||
type CryptoSetup interface {
|
||||
RunHandshake()
|
||||
StartHandshake() error
|
||||
io.Closer
|
||||
ChangeConnectionID(protocol.ConnectionID)
|
||||
GetSessionTicket() ([]byte, error)
|
||||
|
||||
HandleMessage([]byte, protocol.EncryptionLevel) bool
|
||||
HandleMessage([]byte, protocol.EncryptionLevel) error
|
||||
SetLargest1RTTAcked(protocol.PacketNumber) error
|
||||
SetHandshakeConfirmed()
|
||||
ConnectionState() ConnectionState
|
||||
|
|
|
@ -47,18 +47,6 @@ func (mr *MockHandshakeRunnerMockRecorder) DropKeys(arg0 interface{}) *gomock.Ca
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropKeys", reflect.TypeOf((*MockHandshakeRunner)(nil).DropKeys), arg0)
|
||||
}
|
||||
|
||||
// OnError mocks base method.
|
||||
func (m *MockHandshakeRunner) OnError(arg0 error) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "OnError", arg0)
|
||||
}
|
||||
|
||||
// OnError indicates an expected call of OnError.
|
||||
func (mr *MockHandshakeRunnerMockRecorder) OnError(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnError", reflect.TypeOf((*MockHandshakeRunner)(nil).OnError), arg0)
|
||||
}
|
||||
|
||||
// OnHandshakeComplete mocks base method.
|
||||
func (m *MockHandshakeRunner) OnHandshakeComplete() {
|
||||
m.ctrl.T.Helper()
|
||||
|
@ -82,3 +70,15 @@ func (mr *MockHandshakeRunnerMockRecorder) OnReceivedParams(arg0 interface{}) *g
|
|||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnReceivedParams", reflect.TypeOf((*MockHandshakeRunner)(nil).OnReceivedParams), arg0)
|
||||
}
|
||||
|
||||
// OnReceivedReadKeys mocks base method.
|
||||
func (m *MockHandshakeRunner) OnReceivedReadKeys() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "OnReceivedReadKeys")
|
||||
}
|
||||
|
||||
// OnReceivedReadKeys indicates an expected call of OnReceivedReadKeys.
|
||||
func (mr *MockHandshakeRunnerMockRecorder) OnReceivedReadKeys() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnReceivedReadKeys", reflect.TypeOf((*MockHandshakeRunner)(nil).OnReceivedReadKeys))
|
||||
}
|
||||
|
|
|
@ -10,7 +10,7 @@ import (
|
|||
"github.com/quic-go/quic-go/quicvarint"
|
||||
)
|
||||
|
||||
const sessionTicketRevision = 2
|
||||
const sessionTicketRevision = 3
|
||||
|
||||
type sessionTicket struct {
|
||||
Parameters *wire.TransportParameters
|
||||
|
|
|
@ -1,61 +0,0 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qtls"
|
||||
)
|
||||
|
||||
const quicTLSExtensionType = 0x39
|
||||
|
||||
type extensionHandler struct {
|
||||
ourParams []byte
|
||||
paramsChan chan []byte
|
||||
|
||||
extensionType uint16
|
||||
|
||||
perspective protocol.Perspective
|
||||
}
|
||||
|
||||
var _ tlsExtensionHandler = &extensionHandler{}
|
||||
|
||||
// newExtensionHandler creates a new extension handler
|
||||
func newExtensionHandler(params []byte, pers protocol.Perspective) tlsExtensionHandler {
|
||||
return &extensionHandler{
|
||||
ourParams: params,
|
||||
paramsChan: make(chan []byte),
|
||||
perspective: pers,
|
||||
extensionType: quicTLSExtensionType,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *extensionHandler) GetExtensions(msgType uint8) []qtls.Extension {
|
||||
if (h.perspective == protocol.PerspectiveClient && messageType(msgType) != typeClientHello) ||
|
||||
(h.perspective == protocol.PerspectiveServer && messageType(msgType) != typeEncryptedExtensions) {
|
||||
return nil
|
||||
}
|
||||
return []qtls.Extension{{
|
||||
Type: h.extensionType,
|
||||
Data: h.ourParams,
|
||||
}}
|
||||
}
|
||||
|
||||
func (h *extensionHandler) ReceivedExtensions(msgType uint8, exts []qtls.Extension) {
|
||||
if (h.perspective == protocol.PerspectiveClient && messageType(msgType) != typeEncryptedExtensions) ||
|
||||
(h.perspective == protocol.PerspectiveServer && messageType(msgType) != typeClientHello) {
|
||||
return
|
||||
}
|
||||
|
||||
var data []byte
|
||||
for _, ext := range exts {
|
||||
if ext.Type == h.extensionType {
|
||||
data = ext.Data
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
h.paramsChan <- data
|
||||
}
|
||||
|
||||
func (h *extensionHandler) TransportParameters() <-chan []byte {
|
||||
return h.paramsChan
|
||||
}
|
|
@ -1,165 +0,0 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qtls"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("TLS Extension Handler, for the server", func() {
|
||||
var (
|
||||
handlerServer tlsExtensionHandler
|
||||
handlerClient tlsExtensionHandler
|
||||
)
|
||||
|
||||
JustBeforeEach(func() {
|
||||
handlerServer = newExtensionHandler([]byte("foobar"), protocol.PerspectiveServer)
|
||||
handlerClient = newExtensionHandler([]byte("raboof"), protocol.PerspectiveClient)
|
||||
})
|
||||
|
||||
Context("for the server", func() {
|
||||
Context("sending", func() {
|
||||
It("only adds TransportParameters for the Encrypted Extensions", func() {
|
||||
// test 2 other handshake types
|
||||
Expect(handlerServer.GetExtensions(uint8(typeCertificate))).To(BeEmpty())
|
||||
Expect(handlerServer.GetExtensions(uint8(typeFinished))).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("adds TransportParameters to the EncryptedExtensions message", func() {
|
||||
exts := handlerServer.GetExtensions(uint8(typeEncryptedExtensions))
|
||||
Expect(exts).To(HaveLen(1))
|
||||
Expect(exts[0].Type).To(BeEquivalentTo(quicTLSExtensionType))
|
||||
Expect(exts[0].Data).To(Equal([]byte("foobar")))
|
||||
})
|
||||
})
|
||||
|
||||
Context("receiving", func() {
|
||||
var chExts []qtls.Extension
|
||||
|
||||
JustBeforeEach(func() {
|
||||
chExts = handlerClient.GetExtensions(uint8(typeClientHello))
|
||||
Expect(chExts).To(HaveLen(1))
|
||||
})
|
||||
|
||||
It("sends the extension on the channel", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
handlerServer.ReceivedExtensions(uint8(typeClientHello), chExts)
|
||||
}()
|
||||
|
||||
var data []byte
|
||||
Eventually(handlerServer.TransportParameters()).Should(Receive(&data))
|
||||
Expect(data).To(Equal([]byte("raboof")))
|
||||
})
|
||||
|
||||
It("sends nil on the channel if the extension is missing", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
handlerServer.ReceivedExtensions(uint8(typeClientHello), nil)
|
||||
}()
|
||||
|
||||
var data []byte
|
||||
Eventually(handlerServer.TransportParameters()).Should(Receive(&data))
|
||||
Expect(data).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("ignores extensions with different code points", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
exts := []qtls.Extension{{Type: 0x1337, Data: []byte("invalid")}}
|
||||
handlerServer.ReceivedExtensions(uint8(typeClientHello), exts)
|
||||
}()
|
||||
|
||||
var data []byte
|
||||
Eventually(handlerServer.TransportParameters()).Should(Receive())
|
||||
Expect(data).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("ignores extensions that are not sent with the ClientHello", func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
handlerServer.ReceivedExtensions(uint8(typeFinished), chExts)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
Consistently(handlerServer.TransportParameters()).ShouldNot(Receive())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("for the client", func() {
|
||||
Context("sending", func() {
|
||||
It("only adds TransportParameters for the Encrypted Extensions", func() {
|
||||
// test 2 other handshake types
|
||||
Expect(handlerClient.GetExtensions(uint8(typeCertificate))).To(BeEmpty())
|
||||
Expect(handlerClient.GetExtensions(uint8(typeFinished))).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("adds TransportParameters to the ClientHello message", func() {
|
||||
exts := handlerClient.GetExtensions(uint8(typeClientHello))
|
||||
Expect(exts).To(HaveLen(1))
|
||||
Expect(exts[0].Type).To(BeEquivalentTo(quicTLSExtensionType))
|
||||
Expect(exts[0].Data).To(Equal([]byte("raboof")))
|
||||
})
|
||||
})
|
||||
|
||||
Context("receiving", func() {
|
||||
var chExts []qtls.Extension
|
||||
|
||||
JustBeforeEach(func() {
|
||||
chExts = handlerServer.GetExtensions(uint8(typeEncryptedExtensions))
|
||||
Expect(chExts).To(HaveLen(1))
|
||||
})
|
||||
|
||||
It("sends the extension on the channel", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
handlerClient.ReceivedExtensions(uint8(typeEncryptedExtensions), chExts)
|
||||
}()
|
||||
|
||||
var data []byte
|
||||
Eventually(handlerClient.TransportParameters()).Should(Receive(&data))
|
||||
Expect(data).To(Equal([]byte("foobar")))
|
||||
})
|
||||
|
||||
It("sends nil on the channel if the extension is missing", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
handlerClient.ReceivedExtensions(uint8(typeEncryptedExtensions), nil)
|
||||
}()
|
||||
|
||||
var data []byte
|
||||
Eventually(handlerClient.TransportParameters()).Should(Receive(&data))
|
||||
Expect(data).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("ignores extensions with different code points", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
exts := []qtls.Extension{{Type: 0x1337, Data: []byte("invalid")}}
|
||||
handlerClient.ReceivedExtensions(uint8(typeEncryptedExtensions), exts)
|
||||
}()
|
||||
|
||||
var data []byte
|
||||
Eventually(handlerClient.TransportParameters()).Should(Receive())
|
||||
Expect(data).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("ignores extensions that are not sent with the EncryptedExtensions", func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
handlerClient.ReceivedExtensions(uint8(typeFinished), chExts)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
Consistently(handlerClient.TransportParameters()).ShouldNot(Receive())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
|
@ -10,7 +10,6 @@ import (
|
|||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/qtls"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
)
|
||||
|
@ -24,7 +23,7 @@ var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval
|
|||
var FirstKeyUpdateInterval uint64 = 100
|
||||
|
||||
type updatableAEAD struct {
|
||||
suite *qtls.CipherSuiteTLS13
|
||||
suite *cipherSuite
|
||||
|
||||
keyPhase protocol.KeyPhase
|
||||
largestAcked protocol.PacketNumber
|
||||
|
@ -121,7 +120,7 @@ func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte
|
|||
// SetReadKey sets the read key.
|
||||
// For the client, this function is called before SetWriteKey.
|
||||
// For the server, this function is called after SetWriteKey.
|
||||
func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
|
||||
func (a *updatableAEAD) SetReadKey(suite *cipherSuite, trafficSecret []byte) {
|
||||
a.rcvAEAD = createAEAD(suite, trafficSecret, a.version)
|
||||
a.headerDecrypter = newHeaderProtector(suite, trafficSecret, false, a.version)
|
||||
if a.suite == nil {
|
||||
|
@ -135,7 +134,7 @@ func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret [
|
|||
// SetWriteKey sets the write key.
|
||||
// For the client, this function is called after SetReadKey.
|
||||
// For the server, this function is called before SetWriteKey.
|
||||
func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
|
||||
func (a *updatableAEAD) SetWriteKey(suite *cipherSuite, trafficSecret []byte) {
|
||||
a.sendAEAD = createAEAD(suite, trafficSecret, a.version)
|
||||
a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false, a.version)
|
||||
if a.suite == nil {
|
||||
|
@ -146,7 +145,7 @@ func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret
|
|||
a.nextSendAEAD = createAEAD(suite, a.nextSendTrafficSecret, a.version)
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) setAEADParameters(aead cipher.AEAD, suite *qtls.CipherSuiteTLS13) {
|
||||
func (a *updatableAEAD) setAEADParameters(aead cipher.AEAD, suite *cipherSuite) {
|
||||
a.nonceBuf = make([]byte, aead.NonceSize())
|
||||
a.aeadOverhead = aead.Overhead()
|
||||
a.suite = suite
|
||||
|
|
|
@ -10,7 +10,6 @@ import (
|
|||
gomock "github.com/golang/mock/gomock"
|
||||
handshake "github.com/quic-go/quic-go/internal/handshake"
|
||||
protocol "github.com/quic-go/quic-go/internal/protocol"
|
||||
qtls "github.com/quic-go/quic-go/internal/qtls"
|
||||
)
|
||||
|
||||
// MockCryptoSetup is a mock of CryptoSetup interface.
|
||||
|
@ -63,10 +62,10 @@ func (mr *MockCryptoSetupMockRecorder) Close() *gomock.Call {
|
|||
}
|
||||
|
||||
// ConnectionState mocks base method.
|
||||
func (m *MockCryptoSetup) ConnectionState() qtls.ConnectionState {
|
||||
func (m *MockCryptoSetup) ConnectionState() handshake.ConnectionState {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ConnectionState")
|
||||
ret0, _ := ret[0].(qtls.ConnectionState)
|
||||
ret0, _ := ret[0].(handshake.ConnectionState)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
@ -212,10 +211,10 @@ func (mr *MockCryptoSetupMockRecorder) GetSessionTicket() *gomock.Call {
|
|||
}
|
||||
|
||||
// HandleMessage mocks base method.
|
||||
func (m *MockCryptoSetup) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) bool {
|
||||
func (m *MockCryptoSetup) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "HandleMessage", arg0, arg1)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
@ -225,18 +224,6 @@ func (mr *MockCryptoSetupMockRecorder) HandleMessage(arg0, arg1 interface{}) *go
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoSetup)(nil).HandleMessage), arg0, arg1)
|
||||
}
|
||||
|
||||
// RunHandshake mocks base method.
|
||||
func (m *MockCryptoSetup) RunHandshake() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "RunHandshake")
|
||||
}
|
||||
|
||||
// RunHandshake indicates an expected call of RunHandshake.
|
||||
func (mr *MockCryptoSetupMockRecorder) RunHandshake() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RunHandshake", reflect.TypeOf((*MockCryptoSetup)(nil).RunHandshake))
|
||||
}
|
||||
|
||||
// SetHandshakeConfirmed mocks base method.
|
||||
func (m *MockCryptoSetup) SetHandshakeConfirmed() {
|
||||
m.ctrl.T.Helper()
|
||||
|
@ -262,3 +249,17 @@ func (mr *MockCryptoSetupMockRecorder) SetLargest1RTTAcked(arg0 interface{}) *go
|
|||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLargest1RTTAcked", reflect.TypeOf((*MockCryptoSetup)(nil).SetLargest1RTTAcked), arg0)
|
||||
}
|
||||
|
||||
// StartHandshake mocks base method.
|
||||
func (m *MockCryptoSetup) StartHandshake() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "StartHandshake")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// StartHandshake indicates an expected call of StartHandshake.
|
||||
func (mr *MockCryptoSetupMockRecorder) StartHandshake() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartHandshake", reflect.TypeOf((*MockCryptoSetup)(nil).StartHandshake))
|
||||
}
|
||||
|
|
|
@ -40,7 +40,7 @@ func (e TransportErrorCode) Message() string {
|
|||
if !e.IsCryptoError() {
|
||||
return ""
|
||||
}
|
||||
return qtls.Alert(e - 0x100).Error()
|
||||
return qtls.AlertError(e - 0x100).Error()
|
||||
}
|
||||
|
||||
func (e TransportErrorCode) String() string {
|
||||
|
|
66
internal/qtls/cipher_suite_go121.go
Normal file
66
internal/qtls/cipher_suite_go121.go
Normal file
|
@ -0,0 +1,66 @@
|
|||
//go:build go1.21
|
||||
|
||||
package qtls
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/cipher"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type cipherSuiteTLS13 struct {
|
||||
ID uint16
|
||||
KeyLen int
|
||||
AEAD func(key, fixedNonce []byte) cipher.AEAD
|
||||
Hash crypto.Hash
|
||||
}
|
||||
|
||||
//go:linkname cipherSuiteTLS13ByID crypto/tls.cipherSuiteTLS13ByID
|
||||
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13
|
||||
|
||||
//go:linkname cipherSuitesTLS13 crypto/tls.cipherSuitesTLS13
|
||||
var cipherSuitesTLS13 []unsafe.Pointer
|
||||
|
||||
//go:linkname defaultCipherSuitesTLS13 crypto/tls.defaultCipherSuitesTLS13
|
||||
var defaultCipherSuitesTLS13 []uint16
|
||||
|
||||
//go:linkname defaultCipherSuitesTLS13NoAES crypto/tls.defaultCipherSuitesTLS13NoAES
|
||||
var defaultCipherSuitesTLS13NoAES []uint16
|
||||
|
||||
var cipherSuitesModified bool
|
||||
|
||||
// SetCipherSuite modifies the cipherSuiteTLS13 slice of cipher suites inside qtls
|
||||
// such that it only contains the cipher suite with the chosen id.
|
||||
// The reset function returned resets them back to the original value.
|
||||
func SetCipherSuite(id uint16) (reset func()) {
|
||||
if cipherSuitesModified {
|
||||
panic("cipher suites modified multiple times without resetting")
|
||||
}
|
||||
cipherSuitesModified = true
|
||||
|
||||
origCipherSuitesTLS13 := append([]unsafe.Pointer{}, cipherSuitesTLS13...)
|
||||
origDefaultCipherSuitesTLS13 := append([]uint16{}, defaultCipherSuitesTLS13...)
|
||||
origDefaultCipherSuitesTLS13NoAES := append([]uint16{}, defaultCipherSuitesTLS13NoAES...)
|
||||
// The order is given by the order of the slice elements in cipherSuitesTLS13 in qtls.
|
||||
switch id {
|
||||
case tls.TLS_AES_128_GCM_SHA256:
|
||||
cipherSuitesTLS13 = cipherSuitesTLS13[:1]
|
||||
case tls.TLS_CHACHA20_POLY1305_SHA256:
|
||||
cipherSuitesTLS13 = cipherSuitesTLS13[1:2]
|
||||
case tls.TLS_AES_256_GCM_SHA384:
|
||||
cipherSuitesTLS13 = cipherSuitesTLS13[2:]
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected cipher suite: %d", id))
|
||||
}
|
||||
defaultCipherSuitesTLS13 = []uint16{id}
|
||||
defaultCipherSuitesTLS13NoAES = []uint16{id}
|
||||
|
||||
return func() {
|
||||
cipherSuitesTLS13 = origCipherSuitesTLS13
|
||||
defaultCipherSuitesTLS13 = origDefaultCipherSuitesTLS13
|
||||
defaultCipherSuitesTLS13NoAES = origDefaultCipherSuitesTLS13NoAES
|
||||
cipherSuitesModified = false
|
||||
}
|
||||
}
|
52
internal/qtls/cipher_suite_test.go
Normal file
52
internal/qtls/cipher_suite_test.go
Normal file
|
@ -0,0 +1,52 @@
|
|||
//go:build go1.21
|
||||
|
||||
package qtls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/testdata"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Setting the Cipher Suite", func() {
|
||||
for _, cs := range []uint16{tls.TLS_AES_128_GCM_SHA256, tls.TLS_CHACHA20_POLY1305_SHA256, tls.TLS_AES_256_GCM_SHA384} {
|
||||
cs := cs
|
||||
|
||||
It(fmt.Sprintf("selects %s", tls.CipherSuiteName(cs)), func() {
|
||||
reset := SetCipherSuite(cs)
|
||||
defer reset()
|
||||
|
||||
ln, err := tls.Listen("tcp4", "localhost:0", testdata.GetTLSConfig())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(done)
|
||||
conn, err := ln.Accept()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = conn.Read(make([]byte, 10))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(conn.(*tls.Conn).ConnectionState().CipherSuite).To(Equal(cs))
|
||||
}()
|
||||
|
||||
conn, err := tls.Dial(
|
||||
"tcp4",
|
||||
fmt.Sprintf("localhost:%d", ln.Addr().(*net.TCPAddr).Port),
|
||||
&tls.Config{RootCAs: testdata.GetRootCA()},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = conn.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(conn.ConnectionState().CipherSuite).To(Equal(cs))
|
||||
Expect(conn.Close()).To(Succeed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
}
|
||||
})
|
61
internal/qtls/client_session_cache.go
Normal file
61
internal/qtls/client_session_cache.go
Normal file
|
@ -0,0 +1,61 @@
|
|||
//go:build go1.21
|
||||
|
||||
package qtls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
)
|
||||
|
||||
type clientSessionCache struct {
|
||||
getData func() []byte
|
||||
setData func([]byte)
|
||||
wrapped tls.ClientSessionCache
|
||||
}
|
||||
|
||||
var _ tls.ClientSessionCache = &clientSessionCache{}
|
||||
|
||||
func (c clientSessionCache) Put(key string, cs *tls.ClientSessionState) {
|
||||
if cs == nil {
|
||||
c.wrapped.Put(key, nil)
|
||||
return
|
||||
}
|
||||
ticket, state, err := cs.ResumptionState()
|
||||
if err != nil || state == nil {
|
||||
c.wrapped.Put(key, cs)
|
||||
return
|
||||
}
|
||||
state.Extra = append(state.Extra, addExtraPrefix(c.getData()))
|
||||
newCS, err := tls.NewResumptionState(ticket, state)
|
||||
if err != nil {
|
||||
// It's not clear why this would error. Just save the original state.
|
||||
c.wrapped.Put(key, cs)
|
||||
return
|
||||
}
|
||||
c.wrapped.Put(key, newCS)
|
||||
}
|
||||
|
||||
func (c clientSessionCache) Get(key string) (*tls.ClientSessionState, bool) {
|
||||
cs, ok := c.wrapped.Get(key)
|
||||
if !ok || cs == nil {
|
||||
return cs, ok
|
||||
}
|
||||
ticket, state, err := cs.ResumptionState()
|
||||
if err != nil {
|
||||
// It's not clear why this would error.
|
||||
// Remove the ticket from the session cache, so we don't run into this error over and over again
|
||||
c.wrapped.Put(key, nil)
|
||||
return nil, false
|
||||
}
|
||||
// restore QUIC transport parameters and RTT stored in state.Extra
|
||||
if extra := findExtraData(state.Extra); extra != nil {
|
||||
c.setData(extra)
|
||||
}
|
||||
session, err := tls.NewResumptionState(ticket, state)
|
||||
if err != nil {
|
||||
// It's not clear why this would error.
|
||||
// Remove the ticket from the session cache, so we don't run into this error over and over again
|
||||
c.wrapped.Put(key, nil)
|
||||
return nil, false
|
||||
}
|
||||
return session, true
|
||||
}
|
82
internal/qtls/client_session_cache_test.go
Normal file
82
internal/qtls/client_session_cache_test.go
Normal file
|
@ -0,0 +1,82 @@
|
|||
//go:build go1.21
|
||||
|
||||
package qtls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/testdata"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Client Session Cache", func() {
|
||||
It("adds data to and restores data from a session ticket", func() {
|
||||
ln, err := tls.Listen("tcp4", "localhost:0", testdata.GetTLSConfig())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(done)
|
||||
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = conn.Read(make([]byte, 10))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = conn.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
}()
|
||||
|
||||
restored := make(chan []byte, 1)
|
||||
clientConf := &tls.Config{
|
||||
RootCAs: testdata.GetRootCA(),
|
||||
ClientSessionCache: &clientSessionCache{
|
||||
wrapped: tls.NewLRUClientSessionCache(10),
|
||||
getData: func() []byte { return []byte("session") },
|
||||
setData: func(data []byte) { restored <- data },
|
||||
},
|
||||
}
|
||||
conn, err := tls.Dial(
|
||||
"tcp4",
|
||||
fmt.Sprintf("localhost:%d", ln.Addr().(*net.TCPAddr).Port),
|
||||
clientConf,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = conn.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(conn.ConnectionState().DidResume).To(BeFalse())
|
||||
Expect(restored).To(HaveLen(0))
|
||||
_, err = conn.Read(make([]byte, 10))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(conn.Close()).To(Succeed())
|
||||
|
||||
// make sure the cache can deal with nonsensical inputs
|
||||
clientConf.ClientSessionCache.Put("foo", nil)
|
||||
clientConf.ClientSessionCache.Put("bar", &tls.ClientSessionState{})
|
||||
|
||||
conn, err = tls.Dial(
|
||||
"tcp4",
|
||||
fmt.Sprintf("localhost:%d", ln.Addr().(*net.TCPAddr).Port),
|
||||
clientConf,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = conn.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(conn.ConnectionState().DidResume).To(BeTrue())
|
||||
var restoredData []byte
|
||||
Expect(restored).To(Receive(&restoredData))
|
||||
Expect(restoredData).To(Equal([]byte("session")))
|
||||
Expect(conn.Close()).To(Succeed())
|
||||
|
||||
Expect(ln.Close()).To(Succeed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
})
|
|
@ -1,145 +0,0 @@
|
|||
//go:build go1.19 && !go1.20
|
||||
|
||||
package qtls
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/cipher"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"unsafe"
|
||||
|
||||
"github.com/quic-go/qtls-go1-19"
|
||||
)
|
||||
|
||||
type (
|
||||
// Alert is a TLS alert
|
||||
Alert = qtls.Alert
|
||||
// A Certificate is qtls.Certificate.
|
||||
Certificate = qtls.Certificate
|
||||
// CertificateRequestInfo contains information about a certificate request.
|
||||
CertificateRequestInfo = qtls.CertificateRequestInfo
|
||||
// A CipherSuiteTLS13 is a cipher suite for TLS 1.3
|
||||
CipherSuiteTLS13 = qtls.CipherSuiteTLS13
|
||||
// ClientHelloInfo contains information about a ClientHello.
|
||||
ClientHelloInfo = qtls.ClientHelloInfo
|
||||
// ClientSessionCache is a cache used for session resumption.
|
||||
ClientSessionCache = qtls.ClientSessionCache
|
||||
// ClientSessionState is a state needed for session resumption.
|
||||
ClientSessionState = qtls.ClientSessionState
|
||||
// A Config is a qtls.Config.
|
||||
Config = qtls.Config
|
||||
// A Conn is a qtls.Conn.
|
||||
Conn = qtls.Conn
|
||||
// ConnectionState contains information about the state of the connection.
|
||||
ConnectionState = qtls.ConnectionStateWith0RTT
|
||||
// EncryptionLevel is the encryption level of a message.
|
||||
EncryptionLevel = qtls.EncryptionLevel
|
||||
// Extension is a TLS extension
|
||||
Extension = qtls.Extension
|
||||
// ExtraConfig is the qtls.ExtraConfig
|
||||
ExtraConfig = qtls.ExtraConfig
|
||||
// RecordLayer is a qtls RecordLayer.
|
||||
RecordLayer = qtls.RecordLayer
|
||||
)
|
||||
|
||||
const (
|
||||
// EncryptionHandshake is the Handshake encryption level
|
||||
EncryptionHandshake = qtls.EncryptionHandshake
|
||||
// Encryption0RTT is the 0-RTT encryption level
|
||||
Encryption0RTT = qtls.Encryption0RTT
|
||||
// EncryptionApplication is the application data encryption level
|
||||
EncryptionApplication = qtls.EncryptionApplication
|
||||
)
|
||||
|
||||
// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3
|
||||
func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD {
|
||||
return qtls.AEADAESGCMTLS13(key, fixedNonce)
|
||||
}
|
||||
|
||||
// Client returns a new TLS client side connection.
|
||||
func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
|
||||
return qtls.Client(conn, config, extraConfig)
|
||||
}
|
||||
|
||||
// Server returns a new TLS server side connection.
|
||||
func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
|
||||
return qtls.Server(conn, config, extraConfig)
|
||||
}
|
||||
|
||||
func GetConnectionState(conn *Conn) ConnectionState {
|
||||
return conn.ConnectionStateWith0RTT()
|
||||
}
|
||||
|
||||
// ToTLSConnectionState extracts the tls.ConnectionState
|
||||
func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState {
|
||||
return cs.ConnectionState
|
||||
}
|
||||
|
||||
type cipherSuiteTLS13 struct {
|
||||
ID uint16
|
||||
KeyLen int
|
||||
AEAD func(key, fixedNonce []byte) cipher.AEAD
|
||||
Hash crypto.Hash
|
||||
}
|
||||
|
||||
//go:linkname cipherSuiteTLS13ByID github.com/quic-go/qtls-go1-19.cipherSuiteTLS13ByID
|
||||
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13
|
||||
|
||||
// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite.
|
||||
func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 {
|
||||
val := cipherSuiteTLS13ByID(id)
|
||||
cs := (*cipherSuiteTLS13)(unsafe.Pointer(val))
|
||||
return &qtls.CipherSuiteTLS13{
|
||||
ID: cs.ID,
|
||||
KeyLen: cs.KeyLen,
|
||||
AEAD: cs.AEAD,
|
||||
Hash: cs.Hash,
|
||||
}
|
||||
}
|
||||
|
||||
//go:linkname cipherSuitesTLS13 github.com/quic-go/qtls-go1-19.cipherSuitesTLS13
|
||||
var cipherSuitesTLS13 []unsafe.Pointer
|
||||
|
||||
//go:linkname defaultCipherSuitesTLS13 github.com/quic-go/qtls-go1-19.defaultCipherSuitesTLS13
|
||||
var defaultCipherSuitesTLS13 []uint16
|
||||
|
||||
//go:linkname defaultCipherSuitesTLS13NoAES github.com/quic-go/qtls-go1-19.defaultCipherSuitesTLS13NoAES
|
||||
var defaultCipherSuitesTLS13NoAES []uint16
|
||||
|
||||
var cipherSuitesModified bool
|
||||
|
||||
// SetCipherSuite modifies the cipherSuiteTLS13 slice of cipher suites inside qtls
|
||||
// such that it only contains the cipher suite with the chosen id.
|
||||
// The reset function returned resets them back to the original value.
|
||||
func SetCipherSuite(id uint16) (reset func()) {
|
||||
if cipherSuitesModified {
|
||||
panic("cipher suites modified multiple times without resetting")
|
||||
}
|
||||
cipherSuitesModified = true
|
||||
|
||||
origCipherSuitesTLS13 := append([]unsafe.Pointer{}, cipherSuitesTLS13...)
|
||||
origDefaultCipherSuitesTLS13 := append([]uint16{}, defaultCipherSuitesTLS13...)
|
||||
origDefaultCipherSuitesTLS13NoAES := append([]uint16{}, defaultCipherSuitesTLS13NoAES...)
|
||||
// The order is given by the order of the slice elements in cipherSuitesTLS13 in qtls.
|
||||
switch id {
|
||||
case tls.TLS_AES_128_GCM_SHA256:
|
||||
cipherSuitesTLS13 = cipherSuitesTLS13[:1]
|
||||
case tls.TLS_CHACHA20_POLY1305_SHA256:
|
||||
cipherSuitesTLS13 = cipherSuitesTLS13[1:2]
|
||||
case tls.TLS_AES_256_GCM_SHA384:
|
||||
cipherSuitesTLS13 = cipherSuitesTLS13[2:]
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected cipher suite: %d", id))
|
||||
}
|
||||
defaultCipherSuitesTLS13 = []uint16{id}
|
||||
defaultCipherSuitesTLS13NoAES = []uint16{id}
|
||||
|
||||
return func() {
|
||||
cipherSuitesTLS13 = origCipherSuitesTLS13
|
||||
defaultCipherSuitesTLS13 = origDefaultCipherSuitesTLS13
|
||||
defaultCipherSuitesTLS13NoAES = origDefaultCipherSuitesTLS13NoAES
|
||||
cipherSuitesModified = false
|
||||
}
|
||||
}
|
|
@ -1,101 +1,97 @@
|
|||
//go:build go1.20
|
||||
//go:build go1.20 && !go1.21
|
||||
|
||||
package qtls
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/cipher"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"unsafe"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
|
||||
"github.com/quic-go/qtls-go1-20"
|
||||
)
|
||||
|
||||
type (
|
||||
// Alert is a TLS alert
|
||||
Alert = qtls.Alert
|
||||
// A Certificate is qtls.Certificate.
|
||||
Certificate = qtls.Certificate
|
||||
// CertificateRequestInfo contains information about a certificate request.
|
||||
CertificateRequestInfo = qtls.CertificateRequestInfo
|
||||
// A CipherSuiteTLS13 is a cipher suite for TLS 1.3
|
||||
CipherSuiteTLS13 = qtls.CipherSuiteTLS13
|
||||
// ClientHelloInfo contains information about a ClientHello.
|
||||
ClientHelloInfo = qtls.ClientHelloInfo
|
||||
// ClientSessionCache is a cache used for session resumption.
|
||||
ClientSessionCache = qtls.ClientSessionCache
|
||||
// ClientSessionState is a state needed for session resumption.
|
||||
ClientSessionState = qtls.ClientSessionState
|
||||
// A Config is a qtls.Config.
|
||||
Config = qtls.Config
|
||||
// A Conn is a qtls.Conn.
|
||||
Conn = qtls.Conn
|
||||
// ConnectionState contains information about the state of the connection.
|
||||
ConnectionState = qtls.ConnectionStateWith0RTT
|
||||
// EncryptionLevel is the encryption level of a message.
|
||||
EncryptionLevel = qtls.EncryptionLevel
|
||||
// Extension is a TLS extension
|
||||
Extension = qtls.Extension
|
||||
// ExtraConfig is the qtls.ExtraConfig
|
||||
ExtraConfig = qtls.ExtraConfig
|
||||
// RecordLayer is a qtls RecordLayer.
|
||||
RecordLayer = qtls.RecordLayer
|
||||
QUICConn = qtls.QUICConn
|
||||
QUICConfig = qtls.QUICConfig
|
||||
QUICEvent = qtls.QUICEvent
|
||||
QUICEventKind = qtls.QUICEventKind
|
||||
QUICEncryptionLevel = qtls.QUICEncryptionLevel
|
||||
AlertError = qtls.AlertError
|
||||
)
|
||||
|
||||
const (
|
||||
// EncryptionHandshake is the Handshake encryption level
|
||||
EncryptionHandshake = qtls.EncryptionHandshake
|
||||
// Encryption0RTT is the 0-RTT encryption level
|
||||
Encryption0RTT = qtls.Encryption0RTT
|
||||
// EncryptionApplication is the application data encryption level
|
||||
EncryptionApplication = qtls.EncryptionApplication
|
||||
QUICEncryptionLevelInitial = qtls.QUICEncryptionLevelInitial
|
||||
QUICEncryptionLevelEarly = qtls.QUICEncryptionLevelEarly
|
||||
QUICEncryptionLevelHandshake = qtls.QUICEncryptionLevelHandshake
|
||||
QUICEncryptionLevelApplication = qtls.QUICEncryptionLevelApplication
|
||||
)
|
||||
|
||||
// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3
|
||||
func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD {
|
||||
return qtls.AEADAESGCMTLS13(key, fixedNonce)
|
||||
const (
|
||||
QUICNoEvent = qtls.QUICNoEvent
|
||||
QUICSetReadSecret = qtls.QUICSetReadSecret
|
||||
QUICSetWriteSecret = qtls.QUICSetWriteSecret
|
||||
QUICWriteData = qtls.QUICWriteData
|
||||
QUICTransportParameters = qtls.QUICTransportParameters
|
||||
QUICTransportParametersRequired = qtls.QUICTransportParametersRequired
|
||||
QUICRejectedEarlyData = qtls.QUICRejectedEarlyData
|
||||
QUICHandshakeDone = qtls.QUICHandshakeDone
|
||||
)
|
||||
|
||||
func SetupConfigForServer(conf *QUICConfig, enable0RTT bool, getDataForSessionTicket func() []byte, accept0RTT func([]byte) bool) {
|
||||
qtls.InitSessionTicketKeys(conf.TLSConfig)
|
||||
conf.TLSConfig = conf.TLSConfig.Clone()
|
||||
conf.TLSConfig.MinVersion = tls.VersionTLS13
|
||||
conf.ExtraConfig = &qtls.ExtraConfig{
|
||||
Enable0RTT: enable0RTT,
|
||||
Accept0RTT: accept0RTT,
|
||||
GetAppDataForSessionTicket: getDataForSessionTicket,
|
||||
}
|
||||
}
|
||||
|
||||
// Client returns a new TLS client side connection.
|
||||
func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
|
||||
return qtls.Client(conn, config, extraConfig)
|
||||
func SetupConfigForClient(conf *QUICConfig, getDataForSessionState func() []byte, setDataFromSessionState func([]byte)) {
|
||||
conf.ExtraConfig = &qtls.ExtraConfig{
|
||||
GetAppDataForSessionState: getDataForSessionState,
|
||||
SetAppDataFromSessionState: setDataFromSessionState,
|
||||
}
|
||||
}
|
||||
|
||||
// Server returns a new TLS server side connection.
|
||||
func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
|
||||
return qtls.Server(conn, config, extraConfig)
|
||||
func QUICServer(config *QUICConfig) *QUICConn {
|
||||
return qtls.QUICServer(config)
|
||||
}
|
||||
|
||||
func GetConnectionState(conn *Conn) ConnectionState {
|
||||
return conn.ConnectionStateWith0RTT()
|
||||
func QUICClient(config *QUICConfig) *QUICConn {
|
||||
return qtls.QUICClient(config)
|
||||
}
|
||||
|
||||
// ToTLSConnectionState extracts the tls.ConnectionState
|
||||
func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState {
|
||||
return cs.ConnectionState
|
||||
func ToTLSEncryptionLevel(e protocol.EncryptionLevel) qtls.QUICEncryptionLevel {
|
||||
switch e {
|
||||
case protocol.EncryptionInitial:
|
||||
return qtls.QUICEncryptionLevelInitial
|
||||
case protocol.EncryptionHandshake:
|
||||
return qtls.QUICEncryptionLevelHandshake
|
||||
case protocol.Encryption1RTT:
|
||||
return qtls.QUICEncryptionLevelApplication
|
||||
case protocol.Encryption0RTT:
|
||||
return qtls.QUICEncryptionLevelEarly
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected encryption level: %s", e))
|
||||
}
|
||||
}
|
||||
|
||||
type cipherSuiteTLS13 struct {
|
||||
ID uint16
|
||||
KeyLen int
|
||||
AEAD func(key, fixedNonce []byte) cipher.AEAD
|
||||
Hash crypto.Hash
|
||||
}
|
||||
|
||||
//go:linkname cipherSuiteTLS13ByID github.com/quic-go/qtls-go1-20.cipherSuiteTLS13ByID
|
||||
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13
|
||||
|
||||
// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite.
|
||||
func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 {
|
||||
val := cipherSuiteTLS13ByID(id)
|
||||
cs := (*cipherSuiteTLS13)(unsafe.Pointer(val))
|
||||
return &qtls.CipherSuiteTLS13{
|
||||
ID: cs.ID,
|
||||
KeyLen: cs.KeyLen,
|
||||
AEAD: cs.AEAD,
|
||||
Hash: cs.Hash,
|
||||
func FromTLSEncryptionLevel(e qtls.QUICEncryptionLevel) protocol.EncryptionLevel {
|
||||
switch e {
|
||||
case qtls.QUICEncryptionLevelInitial:
|
||||
return protocol.EncryptionInitial
|
||||
case qtls.QUICEncryptionLevelHandshake:
|
||||
return protocol.EncryptionHandshake
|
||||
case qtls.QUICEncryptionLevelApplication:
|
||||
return protocol.Encryption1RTT
|
||||
case qtls.QUICEncryptionLevelEarly:
|
||||
return protocol.Encryption0RTT
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpect encryption level: %s", e))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
28
internal/qtls/go120_test.go
Normal file
28
internal/qtls/go120_test.go
Normal file
|
@ -0,0 +1,28 @@
|
|||
//go:build !go1.21
|
||||
|
||||
package qtls
|
||||
|
||||
import (
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
|
||||
"github.com/quic-go/qtls-go1-20"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Go 1.20", func() {
|
||||
It("converts to qtls.EncryptionLevel", func() {
|
||||
Expect(ToTLSEncryptionLevel(protocol.EncryptionInitial)).To(Equal(qtls.QUICEncryptionLevelInitial))
|
||||
Expect(ToTLSEncryptionLevel(protocol.EncryptionHandshake)).To(Equal(qtls.QUICEncryptionLevelHandshake))
|
||||
Expect(ToTLSEncryptionLevel(protocol.Encryption1RTT)).To(Equal(qtls.QUICEncryptionLevelApplication))
|
||||
Expect(ToTLSEncryptionLevel(protocol.Encryption0RTT)).To(Equal(qtls.QUICEncryptionLevelEarly))
|
||||
})
|
||||
|
||||
It("converts from qtls.EncryptionLevel", func() {
|
||||
Expect(FromTLSEncryptionLevel(qtls.QUICEncryptionLevelInitial)).To(Equal(protocol.EncryptionInitial))
|
||||
Expect(FromTLSEncryptionLevel(qtls.QUICEncryptionLevelHandshake)).To(Equal(protocol.EncryptionHandshake))
|
||||
Expect(FromTLSEncryptionLevel(qtls.QUICEncryptionLevelApplication)).To(Equal(protocol.Encryption1RTT))
|
||||
Expect(FromTLSEncryptionLevel(qtls.QUICEncryptionLevelEarly)).To(Equal(protocol.Encryption0RTT))
|
||||
})
|
||||
})
|
|
@ -2,4 +2,153 @@
|
|||
|
||||
package qtls
|
||||
|
||||
var _ int = "The version of quic-go you're using can't be built on Go 1.21 yet. For more details, please see https://github.com/quic-go/quic-go/wiki/quic-go-and-Go-versions."
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
type (
|
||||
QUICConn = tls.QUICConn
|
||||
QUICConfig = tls.QUICConfig
|
||||
QUICEvent = tls.QUICEvent
|
||||
QUICEventKind = tls.QUICEventKind
|
||||
QUICEncryptionLevel = tls.QUICEncryptionLevel
|
||||
AlertError = tls.AlertError
|
||||
)
|
||||
|
||||
const (
|
||||
QUICEncryptionLevelInitial = tls.QUICEncryptionLevelInitial
|
||||
QUICEncryptionLevelEarly = tls.QUICEncryptionLevelEarly
|
||||
QUICEncryptionLevelHandshake = tls.QUICEncryptionLevelHandshake
|
||||
QUICEncryptionLevelApplication = tls.QUICEncryptionLevelApplication
|
||||
)
|
||||
|
||||
const (
|
||||
QUICNoEvent = tls.QUICNoEvent
|
||||
QUICSetReadSecret = tls.QUICSetReadSecret
|
||||
QUICSetWriteSecret = tls.QUICSetWriteSecret
|
||||
QUICWriteData = tls.QUICWriteData
|
||||
QUICTransportParameters = tls.QUICTransportParameters
|
||||
QUICTransportParametersRequired = tls.QUICTransportParametersRequired
|
||||
QUICRejectedEarlyData = tls.QUICRejectedEarlyData
|
||||
QUICHandshakeDone = tls.QUICHandshakeDone
|
||||
)
|
||||
|
||||
func QUICServer(config *QUICConfig) *QUICConn { return tls.QUICServer(config) }
|
||||
func QUICClient(config *QUICConfig) *QUICConn { return tls.QUICClient(config) }
|
||||
|
||||
func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, accept0RTT func([]byte) bool) {
|
||||
conf := qconf.TLSConfig
|
||||
|
||||
// Workaround for https://github.com/golang/go/issues/60506.
|
||||
// This initializes the session tickets _before_ cloning the config.
|
||||
_, _ = conf.DecryptTicket(nil, tls.ConnectionState{})
|
||||
|
||||
conf = conf.Clone()
|
||||
conf.MinVersion = tls.VersionTLS13
|
||||
qconf.TLSConfig = conf
|
||||
|
||||
// add callbacks to save transport parameters into the session ticket
|
||||
origWrapSession := conf.WrapSession
|
||||
conf.WrapSession = func(cs tls.ConnectionState, state *tls.SessionState) ([]byte, error) {
|
||||
// Add QUIC transport parameters if this is a 0-RTT packet.
|
||||
// TODO(#3853): also save the RTT for non-0-RTT tickets
|
||||
if state.EarlyData {
|
||||
state.Extra = append(state.Extra, addExtraPrefix(getData()))
|
||||
}
|
||||
if origWrapSession != nil {
|
||||
return origWrapSession(cs, state)
|
||||
}
|
||||
b, err := conf.EncryptTicket(cs, state)
|
||||
return b, err
|
||||
}
|
||||
origUnwrapSession := conf.UnwrapSession
|
||||
// UnwrapSession might be called multiple times, as the client can use multiple session tickets.
|
||||
// However, using 0-RTT is only possible with the first session ticket.
|
||||
// crypto/tls guarantees that this callback is called in the same order as the session ticket in the ClientHello.
|
||||
var unwrapCount int
|
||||
conf.UnwrapSession = func(identity []byte, connState tls.ConnectionState) (*tls.SessionState, error) {
|
||||
unwrapCount++
|
||||
var state *tls.SessionState
|
||||
var err error
|
||||
if origUnwrapSession != nil {
|
||||
state, err = origUnwrapSession(identity, connState)
|
||||
} else {
|
||||
state, err = conf.DecryptTicket(identity, connState)
|
||||
}
|
||||
if err != nil || state == nil {
|
||||
return nil, err
|
||||
}
|
||||
if state.EarlyData {
|
||||
extra := findExtraData(state.Extra)
|
||||
if unwrapCount == 1 && extra != nil { // first session ticket
|
||||
state.EarlyData = accept0RTT(extra)
|
||||
} else { // subsequent session ticket, can't be used for 0-RTT
|
||||
state.EarlyData = false
|
||||
}
|
||||
}
|
||||
return state, nil
|
||||
}
|
||||
}
|
||||
|
||||
func SetupConfigForClient(qconf *QUICConfig, getData func() []byte, setData func([]byte)) {
|
||||
conf := qconf.TLSConfig
|
||||
if conf.ClientSessionCache != nil {
|
||||
origCache := conf.ClientSessionCache
|
||||
conf.ClientSessionCache = &clientSessionCache{
|
||||
wrapped: origCache,
|
||||
getData: getData,
|
||||
setData: setData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ToTLSEncryptionLevel(e protocol.EncryptionLevel) tls.QUICEncryptionLevel {
|
||||
switch e {
|
||||
case protocol.EncryptionInitial:
|
||||
return tls.QUICEncryptionLevelInitial
|
||||
case protocol.EncryptionHandshake:
|
||||
return tls.QUICEncryptionLevelHandshake
|
||||
case protocol.Encryption1RTT:
|
||||
return tls.QUICEncryptionLevelApplication
|
||||
case protocol.Encryption0RTT:
|
||||
return tls.QUICEncryptionLevelEarly
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected encryption level: %s", e))
|
||||
}
|
||||
}
|
||||
|
||||
func FromTLSEncryptionLevel(e tls.QUICEncryptionLevel) protocol.EncryptionLevel {
|
||||
switch e {
|
||||
case tls.QUICEncryptionLevelInitial:
|
||||
return protocol.EncryptionInitial
|
||||
case tls.QUICEncryptionLevelHandshake:
|
||||
return protocol.EncryptionHandshake
|
||||
case tls.QUICEncryptionLevelApplication:
|
||||
return protocol.Encryption1RTT
|
||||
case tls.QUICEncryptionLevelEarly:
|
||||
return protocol.Encryption0RTT
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpect encryption level: %s", e))
|
||||
}
|
||||
}
|
||||
|
||||
const extraPrefix = "quic-go1"
|
||||
|
||||
func addExtraPrefix(b []byte) []byte {
|
||||
return append([]byte(extraPrefix), b...)
|
||||
}
|
||||
|
||||
func findExtraData(extras [][]byte) []byte {
|
||||
prefix := []byte(extraPrefix)
|
||||
for _, extra := range extras {
|
||||
if len(extra) < len(prefix) || !bytes.Equal(prefix, extra[:len(prefix)]) {
|
||||
continue
|
||||
}
|
||||
return extra[len(prefix):]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
55
internal/qtls/go121_test.go
Normal file
55
internal/qtls/go121_test.go
Normal file
|
@ -0,0 +1,55 @@
|
|||
//go:build go1.21
|
||||
|
||||
package qtls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Go 1.21", func() {
|
||||
It("converts to tls.EncryptionLevel", func() {
|
||||
Expect(ToTLSEncryptionLevel(protocol.EncryptionInitial)).To(Equal(tls.QUICEncryptionLevelInitial))
|
||||
Expect(ToTLSEncryptionLevel(protocol.EncryptionHandshake)).To(Equal(tls.QUICEncryptionLevelHandshake))
|
||||
Expect(ToTLSEncryptionLevel(protocol.Encryption1RTT)).To(Equal(tls.QUICEncryptionLevelApplication))
|
||||
Expect(ToTLSEncryptionLevel(protocol.Encryption0RTT)).To(Equal(tls.QUICEncryptionLevelEarly))
|
||||
})
|
||||
|
||||
It("converts from tls.EncryptionLevel", func() {
|
||||
Expect(FromTLSEncryptionLevel(tls.QUICEncryptionLevelInitial)).To(Equal(protocol.EncryptionInitial))
|
||||
Expect(FromTLSEncryptionLevel(tls.QUICEncryptionLevelHandshake)).To(Equal(protocol.EncryptionHandshake))
|
||||
Expect(FromTLSEncryptionLevel(tls.QUICEncryptionLevelApplication)).To(Equal(protocol.Encryption1RTT))
|
||||
Expect(FromTLSEncryptionLevel(tls.QUICEncryptionLevelEarly)).To(Equal(protocol.Encryption0RTT))
|
||||
})
|
||||
|
||||
Context("setting up a tls.Config for the client", func() {
|
||||
It("sets up a session cache if there's one present on the config", func() {
|
||||
csc := tls.NewLRUClientSessionCache(1)
|
||||
conf := &QUICConfig{TLSConfig: &tls.Config{ClientSessionCache: csc}}
|
||||
SetupConfigForClient(conf, nil, nil)
|
||||
Expect(conf.TLSConfig.ClientSessionCache).ToNot(BeNil())
|
||||
Expect(conf.TLSConfig.ClientSessionCache).ToNot(Equal(csc))
|
||||
})
|
||||
|
||||
It("doesn't set up a session cache if there's none present on the config", func() {
|
||||
conf := &QUICConfig{TLSConfig: &tls.Config{}}
|
||||
SetupConfigForClient(conf, nil, nil)
|
||||
Expect(conf.TLSConfig.ClientSessionCache).To(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Context("setting up a tls.Config for the server", func() {
|
||||
It("sets the minimum TLS version to TLS 1.3", func() {
|
||||
orig := &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
conf := &QUICConfig{TLSConfig: orig}
|
||||
SetupConfigForServer(conf, false, nil, nil)
|
||||
Expect(conf.TLSConfig.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
|
||||
// check that the original config wasn't modified
|
||||
Expect(orig.MinVersion).To(BeEquivalentTo(tls.VersionTLS12))
|
||||
})
|
||||
})
|
||||
})
|
|
@ -1,4 +1,4 @@
|
|||
//go:build !go1.19
|
||||
//go:build !go1.20
|
||||
|
||||
package qtls
|
||||
|
||||
|
|
|
@ -1,17 +0,0 @@
|
|||
package qtls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("qtls wrapper", func() {
|
||||
It("gets cipher suites", func() {
|
||||
for _, id := range []uint16{tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384, tls.TLS_CHACHA20_POLY1305_SHA256} {
|
||||
cs := CipherSuiteTLS13ByID(id)
|
||||
Expect(cs.ID).To(Equal(id))
|
||||
}
|
||||
})
|
||||
})
|
1
internal/testdata/cert.go
vendored
1
internal/testdata/cert.go
vendored
|
@ -31,6 +31,7 @@ func GetTLSConfig() *tls.Config {
|
|||
panic(err)
|
||||
}
|
||||
return &tls.Config{
|
||||
MinVersion: tls.VersionTLS13,
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
}
|
||||
|
|
|
@ -35,10 +35,10 @@ func (m *MockCryptoDataHandler) EXPECT() *MockCryptoDataHandlerMockRecorder {
|
|||
}
|
||||
|
||||
// HandleMessage mocks base method.
|
||||
func (m *MockCryptoDataHandler) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) bool {
|
||||
func (m *MockCryptoDataHandler) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "HandleMessage", arg0, arg1)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
|
|
@ -156,6 +156,8 @@ func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config
|
|||
if t.isSingleUse {
|
||||
onClose = func() { t.Close() }
|
||||
}
|
||||
tlsConf = tlsConf.Clone()
|
||||
tlsConf.MinVersion = tls.VersionTLS13
|
||||
return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false)
|
||||
}
|
||||
|
||||
|
@ -172,6 +174,8 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C
|
|||
if t.isSingleUse {
|
||||
onClose = func() { t.Close() }
|
||||
}
|
||||
tlsConf = tlsConf.Clone()
|
||||
tlsConf.MinVersion = tls.VersionTLS13
|
||||
return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true)
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue