diff --git a/client.go b/client.go index 2357c120..cc8c1bd1 100644 --- a/client.go +++ b/client.go @@ -57,11 +57,11 @@ func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, conf *Confi if err != nil { return nil, err } - dl, err := setupTransport(udpConn, tlsConf, true) + tr, err := setupTransport(udpConn, tlsConf, true) if err != nil { return nil, err } - return dl.Dial(ctx, udpAddr, tlsConf, conf) + return tr.dial(ctx, udpAddr, addr, tlsConf, conf, false) } // DialAddrEarly establishes a new 0-RTT QUIC connection to a server. @@ -75,13 +75,13 @@ func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, conf * if err != nil { return nil, err } - dl, err := setupTransport(udpConn, tlsConf, true) + tr, err := setupTransport(udpConn, tlsConf, true) if err != nil { return nil, err } - conn, err := dl.DialEarly(ctx, udpAddr, tlsConf, conf) + conn, err := tr.dial(ctx, udpAddr, addr, tlsConf, conf, true) if err != nil { - dl.Close() + tr.Close() return nil, err } return conn, nil @@ -166,12 +166,6 @@ func dial( } func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config *Config, tlsConf *tls.Config, onClose func(), use0RTT bool) (*client, error) { - if tlsConf == nil { - tlsConf = &tls.Config{} - } else { - tlsConf = tlsConf.Clone() - } - srcConnID, err := connIDGenerator.GenerateConnectionID() if err != nil { return nil, err diff --git a/client_test.go b/client_test.go index 9ccf27de..b12fec93 100644 --- a/client_test.go +++ b/client_test.go @@ -13,10 +13,9 @@ import ( "github.com/refraction-networking/uquic/internal/utils" "github.com/refraction-networking/uquic/logging" - "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) type nullMultiplexer struct{} diff --git a/connection.go b/connection.go index 955f8f8a..270f6ec0 100644 --- a/connection.go +++ b/connection.go @@ -244,7 +244,7 @@ var newConnection = func( handshakeDestConnID: destConnID, srcConnIDLen: srcConnID.Len(), tokenGenerator: tokenGenerator, - oneRTTStream: newCryptoStream(), + oneRTTStream: newCryptoStream(true), perspective: protocol.PerspectiveServer, tracer: tracer, logger: logger, @@ -394,8 +394,7 @@ var newClientConnection = func( ) s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize) - oneRTTStream := newCryptoStream() - + oneRTTStream := newCryptoStream(true) params := &wire.TransportParameters{ InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow), @@ -453,8 +452,8 @@ var newClientConnection = func( } func (s *connection) preSetup() { - s.initialStream = newCryptoStream() - s.handshakeStream = newCryptoStream() + s.initialStream = newCryptoStream(false) + s.handshakeStream = newCryptoStream(false) s.sendQueue = newSendQueue(s.conn) s.retransmissionQueue = newRetransmissionQueue() s.frameParser = wire.NewFrameParser(s.config.EnableDatagrams) diff --git a/connection_test.go b/connection_test.go index 95f1b346..de0983ea 100644 --- a/connection_test.go +++ b/connection_test.go @@ -26,10 +26,9 @@ import ( "github.com/refraction-networking/uquic/internal/wire" "github.com/refraction-networking/uquic/logging" - "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) func areConnsRunning() bool { @@ -2703,8 +2702,9 @@ var _ = Describe("Client Connection", func() { Expect(recreateErr.nextPacketNumber).To(Equal(protocol.PacketNumber(128))) }) - It("it closes when no matching version is found", func() { + It("closes when no matching version is found", func() { errChan := make(chan error, 1) + packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) @@ -2712,7 +2712,6 @@ var _ = Describe("Client Connection", func() { errChan <- conn.run() }() connRunner.EXPECT().Remove(srcConnID).MaxTimes(1) - packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) gomock.InOrder( tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any(), gomock.Any()), tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { diff --git a/crypto_stream.go b/crypto_stream.go index 4ad097ce..0c991089 100644 --- a/crypto_stream.go +++ b/crypto_stream.go @@ -30,10 +30,17 @@ type cryptoStreamImpl struct { writeOffset protocol.ByteCount writeBuf []byte + + // Reassemble TLS handshake messages before returning them from GetCryptoData. + // This is only needed because crypto/tls doesn't correctly handle post-handshake messages. + onlyCompleteMsg bool } -func newCryptoStream() cryptoStream { - return &cryptoStreamImpl{queue: newFrameSorter()} +func newCryptoStream(onlyCompleteMsg bool) cryptoStream { + return &cryptoStreamImpl{ + queue: newFrameSorter(), + onlyCompleteMsg: onlyCompleteMsg, + } } func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error { @@ -71,6 +78,20 @@ func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error { // GetCryptoData retrieves data that was received in CRYPTO frames func (s *cryptoStreamImpl) GetCryptoData() []byte { + if s.onlyCompleteMsg { + if len(s.msgBuf) < 4 { + return nil + } + msgLen := 4 + int(s.msgBuf[1])<<16 + int(s.msgBuf[2])<<8 + int(s.msgBuf[3]) + if len(s.msgBuf) < msgLen { + return nil + } + msg := make([]byte, msgLen) + copy(msg, s.msgBuf[:msgLen]) + s.msgBuf = s.msgBuf[msgLen:] + return msg + } + b := s.msgBuf s.msgBuf = nil return b diff --git a/crypto_stream_test.go b/crypto_stream_test.go index c875bcb5..af6ad986 100644 --- a/crypto_stream_test.go +++ b/crypto_stream_test.go @@ -1,6 +1,7 @@ package quic import ( + "crypto/rand" "fmt" "github.com/refraction-networking/uquic/internal/protocol" @@ -15,7 +16,7 @@ var _ = Describe("Crypto Stream", func() { var str cryptoStream BeforeEach(func() { - str = newCryptoStream() + str = newCryptoStream(false) }) Context("handling incoming data", func() { @@ -137,4 +138,23 @@ var _ = Describe("Crypto Stream", func() { Expect(f.Data).To(Equal([]byte("bar"))) }) }) + + It("reassembles data", func() { + str = newCryptoStream(true) + data := make([]byte, 1337) + l := len(data) - 4 + data[1] = uint8(l >> 16) + data[2] = uint8(l >> 8) + data[3] = uint8(l) + rand.Read(data[4:]) + + for i, b := range data { + Expect(str.GetCryptoData()).To(BeEmpty()) + Expect(str.HandleCryptoFrame(&wire.CryptoFrame{ + Offset: protocol.ByteCount(i), + Data: []byte{b}, + })).To(Succeed()) + } + Expect(str.GetCryptoData()).To(Equal(data)) + }) }) diff --git a/framer_test.go b/framer_test.go index 9a6b0777..1c4944da 100644 --- a/framer_test.go +++ b/framer_test.go @@ -8,10 +8,9 @@ import ( "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/wire" - "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) var _ = Describe("Framer", func() { diff --git a/fuzzing/handshake/fuzz.go b/fuzzing/handshake/fuzz.go index f017e1f0..dd06a5c1 100644 --- a/fuzzing/handshake/fuzz.go +++ b/fuzzing/handshake/fuzz.go @@ -85,33 +85,6 @@ func (m messageType) String() string { } } -func appendSuites(suites []uint16, rand uint8) []uint16 { - const ( - s1 = tls.TLS_AES_128_GCM_SHA256 - s2 = tls.TLS_AES_256_GCM_SHA384 - s3 = tls.TLS_CHACHA20_POLY1305_SHA256 - ) - switch rand % 4 { - default: - return suites - case 1: - return append(suites, s1) - case 2: - return append(suites, s2) - case 3: - return append(suites, s3) - } -} - -// consumes 2 bits -func getSuites(rand uint8) []uint16 { - suites := make([]uint16, 0, 3) - for i := 1; i <= 3; i++ { - suites = appendSuites(suites, rand>>i%4) - } - return suites -} - // consumes 3 bits func getClientAuth(rand uint8) tls.ClientAuthType { switch rand { @@ -148,6 +121,7 @@ func getTransportParameters(seed uint8) *wire.TransportParameters { const maxVarInt = math.MaxUint64 / 4 r := mrand.New(mrand.NewSource(int64(seed))) return &wire.TransportParameters{ + ActiveConnectionIDLimit: 2, InitialMaxData: protocol.ByteCount(r.Int63n(maxVarInt)), InitialMaxStreamDataBidiLocal: protocol.ByteCount(r.Int63n(maxVarInt)), InitialMaxStreamDataBidiRemote: protocol.ByteCount(r.Int63n(maxVarInt)), @@ -207,14 +181,26 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls. SessionTicketKey: sessionTicketKey, } + // This sets the cipher suite for both client and server. + // The way crypto/tls is designed doesn't allow us to set different cipher suites for client and server. + resetCipherSuite := func() {} + switch (runConfig[0] >> 6) % 4 { + case 0: + resetCipherSuite = qtls.SetCipherSuite(tls.TLS_AES_128_GCM_SHA256) + case 1: + resetCipherSuite = qtls.SetCipherSuite(tls.TLS_AES_256_GCM_SHA384) + case 3: + resetCipherSuite = qtls.SetCipherSuite(tls.TLS_CHACHA20_POLY1305_SHA256) + default: + } + defer resetCipherSuite() + enable0RTTClient := helper.NthBit(runConfig[0], 0) enable0RTTServer := helper.NthBit(runConfig[0], 1) sendPostHandshakeMessageToClient := helper.NthBit(runConfig[0], 3) sendPostHandshakeMessageToServer := helper.NthBit(runConfig[0], 4) sendSessionTicket := helper.NthBit(runConfig[0], 5) - clientConf.CipherSuites = getSuites(runConfig[0] >> 6) serverConf.ClientAuth = getClientAuth(runConfig[1] & 0b00000111) - serverConf.CipherSuites = getSuites(runConfig[1] >> 6) serverConf.SessionTicketsDisabled = helper.NthBit(runConfig[1], 3) if helper.NthBit(runConfig[2], 0) { clientConf.RootCAs = x509.NewCertPool() @@ -303,6 +289,7 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls. if err := client.StartHandshake(); err != nil { log.Fatal(err) } + defer client.Close() server := handshake.NewCryptoSetupServer( protocol.ConnectionID{}, @@ -319,12 +306,13 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls. if err := server.StartHandshake(); err != nil { log.Fatal(err) } + defer server.Close() var clientHandshakeComplete, serverHandshakeComplete bool for { + var processedEvent bool clientLoop: for { - var processedEvent bool ev := client.NextEvent() //nolint:exhaustive // only need to process a few events switch ev.Kind { @@ -335,11 +323,16 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls. break clientLoop case handshake.EventWriteInitialData, handshake.EventWriteHandshakeData: msg := ev.Data + encLevel := protocol.EncryptionInitial + if ev.Kind == handshake.EventWriteHandshakeData { + encLevel = protocol.EncryptionHandshake + } if msg[0] == messageToReplace { fmt.Printf("replacing %s message to the server with %s at %s\n", messageType(msg[0]), messageType(data[0]), messageToReplaceEncLevel) msg = data + encLevel = messageToReplaceEncLevel } - if err := server.HandleMessage(msg, messageToReplaceEncLevel); err != nil { + if err := server.HandleMessage(msg, encLevel); err != nil { return 1 } case handshake.EventHandshakeComplete: @@ -348,9 +341,9 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls. processedEvent = true } + processedEvent = false serverLoop: for { - var processedEvent bool ev := server.NextEvent() //nolint:exhaustive // only need to process a few events switch ev.Kind { @@ -360,12 +353,17 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls. } break serverLoop case handshake.EventWriteInitialData, handshake.EventWriteHandshakeData: + encLevel := protocol.EncryptionInitial + if ev.Kind == handshake.EventWriteHandshakeData { + encLevel = protocol.EncryptionHandshake + } msg := ev.Data if msg[0] == messageToReplace { fmt.Printf("replacing %s message to the client with %s at %s\n", messageType(msg[0]), messageType(data[0]), messageToReplaceEncLevel) msg = data + encLevel = messageToReplaceEncLevel } - if err := client.HandleMessage(msg, messageToReplaceEncLevel); err != nil { + if err := client.HandleMessage(msg, encLevel); err != nil { return 1 } case handshake.EventHandshakeComplete: @@ -410,10 +408,13 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls. } client.HandleMessage(ticket, protocol.Encryption1RTT) } + if sendPostHandshakeMessageToClient { + fmt.Println("sending post handshake message to the client at", messageToReplaceEncLevel) client.HandleMessage(data, messageToReplaceEncLevel) } if sendPostHandshakeMessageToServer { + fmt.Println("sending post handshake message to the server at", messageToReplaceEncLevel) server.HandleMessage(data, messageToReplaceEncLevel) } diff --git a/go.sum b/go.sum index 151f2f57..a1f36fcd 100644 --- a/go.sum +++ b/go.sum @@ -40,8 +40,6 @@ github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfU github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= -github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= @@ -133,8 +131,9 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= -github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= +go.uber.org/mock v0.2.0 h1:TaP3xedm7JaAgScZO7tlvlKrqT0p7I6OsdGB5YNSMDU= +go.uber.org/mock v0.2.0/go.mod h1:J0y0rp9L3xiff1+ZBfKxlC1fz2+aO16tw0tsDOixfuM= go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d/go.mod h1:OWs+y06UdEOHN4y+MfF/py+xQ/tYqIWW03b70/CG9Rw= golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= @@ -200,7 +199,6 @@ golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= diff --git a/http3/body_test.go b/http3/body_test.go index 0f64bc76..c5f9bb87 100644 --- a/http3/body_test.go +++ b/http3/body_test.go @@ -6,9 +6,9 @@ import ( quic "github.com/refraction-networking/uquic" mockquic "github.com/refraction-networking/uquic/internal/mocks/quic" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) var _ = Describe("Response Body", func() { diff --git a/http3/client_test.go b/http3/client_test.go index 5b56a1fb..28504f30 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -19,11 +19,11 @@ import ( "github.com/refraction-networking/uquic/internal/utils" "github.com/refraction-networking/uquic/quicvarint" - "github.com/golang/mock/gomock" "github.com/quic-go/qpack" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) var _ = Describe("Client", func() { diff --git a/http3/error_codes.go b/http3/error_codes.go index 515c675e..86e27ff7 100644 --- a/http3/error_codes.go +++ b/http3/error_codes.go @@ -26,7 +26,7 @@ const ( ErrCodeMessageError ErrCode = 0x10e ErrCodeConnectError ErrCode = 0x10f ErrCodeVersionFallback ErrCode = 0x110 - ErrCodeDatagramError ErrCode = 0x4a1268 + ErrCodeDatagramError ErrCode = 0x33 ) func (e ErrCode) String() string { diff --git a/http3/frames.go b/http3/frames.go index 7897b86f..7f6d0fe8 100644 --- a/http3/frames.go +++ b/http3/frames.go @@ -88,7 +88,7 @@ func (f *headersFrame) Append(b []byte) []byte { return quicvarint.Append(b, f.Length) } -const settingDatagram = 0xffd277 +const settingDatagram = 0x33 type settingsFrame struct { Datagram bool diff --git a/http3/http3_suite_test.go b/http3/http3_suite_test.go index 56b2108f..b34003b9 100644 --- a/http3/http3_suite_test.go +++ b/http3/http3_suite_test.go @@ -6,10 +6,9 @@ import ( "testing" "time" - "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) func TestHttp3(t *testing.T) { diff --git a/http3/http_stream_test.go b/http3/http_stream_test.go index 4f87ee55..5e0ff514 100644 --- a/http3/http_stream_test.go +++ b/http3/http_stream_test.go @@ -7,9 +7,9 @@ import ( quic "github.com/refraction-networking/uquic" mockquic "github.com/refraction-networking/uquic/internal/mocks/quic" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) func getDataFrame(data []byte) []byte { diff --git a/http3/mock_quic_early_listener_test.go b/http3/mock_quic_early_listener_test.go index 0e7cf685..b43b3550 100644 --- a/http3/mock_quic_early_listener_test.go +++ b/http3/mock_quic_early_listener_test.go @@ -9,7 +9,7 @@ import ( net "net" reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" quic "github.com/refraction-networking/uquic" ) diff --git a/http3/mock_roundtripcloser_test.go b/http3/mock_roundtripcloser_test.go index f9a82130..6550b255 100644 --- a/http3/mock_roundtripcloser_test.go +++ b/http3/mock_roundtripcloser_test.go @@ -8,7 +8,7 @@ import ( http "net/http" reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" ) // MockRoundTripCloser is a mock of RoundTripCloser interface. diff --git a/http3/mockgen.go b/http3/mockgen.go index a8185a97..cf4b917a 100644 --- a/http3/mockgen.go +++ b/http3/mockgen.go @@ -2,7 +2,7 @@ package http3 -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package http3 -destination mock_roundtripcloser_test.go github.com/refraction-networking/uquic/http3 RoundTripCloser" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package http3 -destination mock_roundtripcloser_test.go github.com/refraction-networking/uquic/http3 RoundTripCloser" type RoundTripCloser = roundTripCloser -//go:generate sh -c "go run github.com/golang/mock/mockgen -package http3 -destination mock_quic_early_listener_test.go github.com/refraction-networking/uquic/http3 QUICEarlyListener" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package http3 -destination mock_quic_early_listener_test.go github.com/refraction-networking/uquic/http3 QUICEarlyListener" diff --git a/http3/request_writer_test.go b/http3/request_writer_test.go index 9de46f90..e590efa1 100644 --- a/http3/request_writer_test.go +++ b/http3/request_writer_test.go @@ -8,8 +8,8 @@ import ( mockquic "github.com/refraction-networking/uquic/internal/mocks/quic" "github.com/refraction-networking/uquic/internal/utils" - "github.com/golang/mock/gomock" "github.com/quic-go/qpack" + "go.uber.org/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" diff --git a/http3/response_writer.go b/http3/response_writer.go index 2262f16c..8eb592a9 100644 --- a/http3/response_writer.go +++ b/http3/response_writer.go @@ -15,19 +15,61 @@ import ( "github.com/quic-go/qpack" ) +// The maximum length of an encoded HTTP/3 frame header is 16: +// The frame has a type and length field, both QUIC varints (maximum 8 bytes in length) +const frameHeaderLen = 16 + +// headerWriter wraps the stream, so that the first Write call flushes the header to the stream +type headerWriter struct { + str quic.Stream + header http.Header + status int // status code passed to WriteHeader + written bool + + logger utils.Logger +} + +// writeHeader encodes and flush header to the stream +func (hw *headerWriter) writeHeader() error { + var headers bytes.Buffer + enc := qpack.NewEncoder(&headers) + enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(hw.status)}) + + for k, v := range hw.header { + for index := range v { + enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]}) + } + } + + buf := make([]byte, 0, frameHeaderLen+headers.Len()) + buf = (&headersFrame{Length: uint64(headers.Len())}).Append(buf) + hw.logger.Infof("Responding with %d", hw.status) + buf = append(buf, headers.Bytes()...) + + _, err := hw.str.Write(buf) + return err +} + +// first Write will trigger flushing header +func (hw *headerWriter) Write(p []byte) (int, error) { + if !hw.written { + if err := hw.writeHeader(); err != nil { + return 0, err + } + hw.written = true + } + return hw.str.Write(p) +} + type responseWriter struct { + *headerWriter conn quic.Connection - str quic.Stream bufferedStr *bufio.Writer buf []byte - header http.Header - status int // status code passed to WriteHeader headerWritten bool contentLen int64 // if handler set valid Content-Length header numWritten int64 // bytes written - - logger utils.Logger } var ( @@ -37,13 +79,16 @@ var ( ) func newResponseWriter(str quic.Stream, conn quic.Connection, logger utils.Logger) *responseWriter { + hw := &headerWriter{ + str: str, + header: http.Header{}, + logger: logger, + } return &responseWriter{ - header: http.Header{}, - buf: make([]byte, 16), - conn: conn, - str: str, - bufferedStr: bufio.NewWriter(str), - logger: logger, + headerWriter: hw, + buf: make([]byte, frameHeaderLen), + conn: conn, + bufferedStr: bufio.NewWriter(hw), } } @@ -83,27 +128,8 @@ func (w *responseWriter) WriteHeader(status int) { } w.status = status - var headers bytes.Buffer - enc := qpack.NewEncoder(&headers) - enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)}) - - for k, v := range w.header { - for index := range v { - enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]}) - } - } - - w.buf = w.buf[:0] - w.buf = (&headersFrame{Length: uint64(headers.Len())}).Append(w.buf) - w.logger.Infof("Responding with %d", status) - if _, err := w.bufferedStr.Write(w.buf); err != nil { - w.logger.Errorf("could not write headers frame: %s", err.Error()) - } - if _, err := w.bufferedStr.Write(headers.Bytes()); err != nil { - w.logger.Errorf("could not write header frame payload: %s", err.Error()) - } if !w.headerWritten { - w.Flush() + w.writeHeader() } } @@ -146,6 +172,15 @@ func (w *responseWriter) Write(p []byte) (int, error) { } func (w *responseWriter) FlushError() error { + if !w.headerWritten { + w.WriteHeader(http.StatusOK) + } + if !w.written { + if err := w.writeHeader(); err != nil { + return err + } + w.written = true + } return w.bufferedStr.Flush() } diff --git a/http3/response_writer_test.go b/http3/response_writer_test.go index 8563ae8a..88c18a0a 100644 --- a/http3/response_writer_test.go +++ b/http3/response_writer_test.go @@ -9,11 +9,11 @@ import ( mockquic "github.com/refraction-networking/uquic/internal/mocks/quic" "github.com/refraction-networking/uquic/internal/utils" - "github.com/golang/mock/gomock" "github.com/quic-go/qpack" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) var _ = Describe("Response Writer", func() { diff --git a/http3/roundtrip.go b/http3/roundtrip.go index df8bc409..3ef5ddd2 100644 --- a/http3/roundtrip.go +++ b/http3/roundtrip.go @@ -52,7 +52,7 @@ type RoundTripper struct { // Enable support for HTTP/3 datagrams. // If set to true, QuicConfig.EnableDatagram will be set. - // See https://www.ietf.org/archive/id/draft-schinazi-masque-h3-datagram-02.html. + // See https://datatracker.ietf.org/doc/html/rfc9297. EnableDatagrams bool // Additional HTTP/3 settings. diff --git a/http3/roundtrip_test.go b/http3/roundtrip_test.go index 279308fd..e685aa61 100644 --- a/http3/roundtrip_test.go +++ b/http3/roundtrip_test.go @@ -14,9 +14,9 @@ import ( "github.com/refraction-networking/uquic/internal/qerr" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) type mockBody struct { diff --git a/http3/server.go b/http3/server.go index b32e80ed..20a0323a 100644 --- a/http3/server.go +++ b/http3/server.go @@ -9,6 +9,7 @@ import ( "net" "net/http" "runtime" + "strconv" "strings" "sync" "time" @@ -33,12 +34,8 @@ var ( } ) -const ( - // NextProtoH3Draft29 is the ALPN protocol negotiated during the TLS handshake, for QUIC draft 29. - NextProtoH3Draft29 = "h3-29" - // NextProtoH3 is the ALPN protocol negotiated during the TLS handshake, for QUIC v1 and v2. - NextProtoH3 = "h3" -) +// NextProtoH3 is the ALPN protocol negotiated during the TLS handshake, for QUIC v1 and v2. +const NextProtoH3 = "h3" // StreamType is the stream type of a unidirectional stream. type StreamType uint64 @@ -178,7 +175,7 @@ type Server struct { // EnableDatagrams enables support for HTTP/3 datagrams. // If set to true, QuicConfig.EnableDatagram will be set. - // See https://datatracker.ietf.org/doc/html/draft-ietf-masque-h3-datagram-07. + // See https://datatracker.ietf.org/doc/html/rfc9297. EnableDatagrams bool // MaxHeaderBytes controls the maximum number of bytes the server will @@ -651,7 +648,12 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q // only write response when there is no panic if !panicked { - r.WriteHeader(http.StatusOK) + // response not written to the client yet, set Content-Length + if !r.written { + if _, haveCL := r.header["Content-Length"]; !haveCL { + r.header.Set("Content-Length", strconv.FormatInt(r.numWritten, 10)) + } + } r.Flush() } // If the EOF was read by the handler, CancelRead() is a no-op. diff --git a/http3/server_test.go b/http3/server_test.go index b33c0f59..5ab313de 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -21,8 +21,8 @@ import ( "github.com/refraction-networking/uquic/internal/utils" "github.com/refraction-networking/uquic/quicvarint" - "github.com/golang/mock/gomock" "github.com/quic-go/qpack" + "go.uber.org/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -181,6 +181,47 @@ var _ = Describe("Server", func() { Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) }) + It("sets Content-Length when the handler doesn't flush to the client", func() { + s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("foobar")) + }) + + responseBuf := &bytes.Buffer{} + setRequest(encodeRequest(exampleGetRequest)) + str.EXPECT().Context().Return(reqContext) + str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() + str.EXPECT().CancelRead(gomock.Any()) + + serr := s.handleRequest(conn, str, qpackDecoder, nil) + Expect(serr.err).ToNot(HaveOccurred()) + hfs := decodeHeader(responseBuf) + Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) + Expect(hfs).To(HaveKeyWithValue("content-length", []string{"6"})) + // status, content-length, date, content-type + Expect(hfs).To(HaveLen(4)) + }) + + It("not sets Content-Length when the handler flushes to the client", func() { + s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("foobar")) + // force flush + w.(http.Flusher).Flush() + }) + + responseBuf := &bytes.Buffer{} + setRequest(encodeRequest(exampleGetRequest)) + str.EXPECT().Context().Return(reqContext) + str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() + str.EXPECT().CancelRead(gomock.Any()) + + serr := s.handleRequest(conn, str, qpackDecoder, nil) + Expect(serr.err).ToNot(HaveOccurred()) + hfs := decodeHeader(responseBuf) + Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) + // status, date, content-type + Expect(hfs).To(HaveLen(3)) + }) + It("handles a aborting handler", func() { s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { panic(http.ErrAbortHandler) diff --git a/integrationtests/gomodvendor/go.sum b/integrationtests/gomodvendor/go.sum index 1336cd7c..e24232c0 100644 --- a/integrationtests/gomodvendor/go.sum +++ b/integrationtests/gomodvendor/go.sum @@ -39,8 +39,6 @@ github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfU github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= -github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= @@ -136,8 +134,8 @@ github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7q github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= -github.com/quic-go/qtls-go1-20 v0.3.0 h1:NrCXmDl8BddZwO67vlvEpBTwT89bJfKYygxv4HQvuDk= -github.com/quic-go/qtls-go1-20 v0.3.0/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= +github.com/quic-go/qtls-go1-20 v0.3.3 h1:17/glZSLI9P9fDAeyCHBFSWSqJcwx1byhLwP5eUIDCM= +github.com/quic-go/qtls-go1-20 v0.3.3/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= @@ -173,10 +171,11 @@ github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cb github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= +go.uber.org/mock v0.2.0 h1:TaP3xedm7JaAgScZO7tlvlKrqT0p7I6OsdGB5YNSMDU= +go.uber.org/mock v0.2.0/go.mod h1:J0y0rp9L3xiff1+ZBfKxlC1fz2+aO16tw0tsDOixfuM= go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d/go.mod h1:OWs+y06UdEOHN4y+MfF/py+xQ/tYqIWW03b70/CG9Rw= golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= @@ -195,7 +194,7 @@ golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTk golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.5.1/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro= golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI= @@ -217,7 +216,6 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= @@ -261,9 +259,7 @@ golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -309,7 +305,7 @@ golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.1.8/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU= golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA= diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index 9def9361..6ab814a8 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -201,6 +201,8 @@ var _ = Describe("Handshake tests", func() { Expect(errors.As(err, &transportErr)).To(BeTrue()) Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue()) Expect(transportErr.Error()).To(ContainSubstring("x509: certificate is valid for localhost, not foo.bar")) + var certErr *tls.CertificateVerificationError + Expect(errors.As(transportErr, &certErr)).To(BeTrue()) }) It("fails the handshake if the client fails to provide the requested client cert", func() { @@ -452,7 +454,7 @@ var _ = Describe("Handshake tests", func() { It("rejects invalid Retry token with the INVALID_TOKEN error", func() { serverConfig.RequireAddressValidation = func(net.Addr) bool { return true } - serverConfig.MaxRetryTokenAge = time.Nanosecond + serverConfig.MaxRetryTokenAge = -time.Second server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index d0474ad3..698f3160 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -42,7 +42,7 @@ var _ = Describe("HTTP tests", func() { rt *http3.RoundTripper server *http3.Server stoppedServing chan struct{} - port string + port int ) BeforeEach(func() { @@ -93,7 +93,7 @@ var _ = Describe("HTTP tests", func() { Expect(err).NotTo(HaveOccurred()) conn, err := net.ListenUDP("udp", addr) Expect(err).NotTo(HaveOccurred()) - port = strconv.Itoa(conn.LocalAddr().(*net.UDPAddr).Port) + port = conn.LocalAddr().(*net.UDPAddr).Port stoppedServing = make(chan struct{}) @@ -120,7 +120,7 @@ var _ = Describe("HTTP tests", func() { }) It("downloads a hello", func() { - resp, err := client.Get("https://localhost:" + port + "/hello") + resp, err := client.Get(fmt.Sprintf("https://localhost:%d/hello", port)) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second)) @@ -128,13 +128,25 @@ var _ = Describe("HTTP tests", func() { Expect(string(body)).To(Equal("Hello, World!\n")) }) + It("sets content-length for small response", func() { + mux.HandleFunc("/small", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + w.Write([]byte("foobar")) + }) + + resp, err := client.Get(fmt.Sprintf("https://localhost:%d/small", port)) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + Expect(resp.Header.Get("Content-Length")).To(Equal(strconv.Itoa(len("foobar")))) + }) + It("requests to different servers with the same udpconn", func() { - resp, err := client.Get("https://localhost:" + port + "/remoteAddr") + resp, err := client.Get(fmt.Sprintf("https://localhost:%d/remoteAddr", port)) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) addr1 := resp.Header.Get("X-RemoteAddr") Expect(addr1).ToNot(Equal("")) - resp, err = client.Get("https://127.0.0.1:" + port + "/remoteAddr") + resp, err = client.Get(fmt.Sprintf("https://127.0.0.1:%d/remoteAddr", port)) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) addr2 := resp.Header.Get("X-RemoteAddr") @@ -146,7 +158,7 @@ var _ = Describe("HTTP tests", func() { group, ctx := errgroup.WithContext(context.Background()) for i := 0; i < 2; i++ { group.Go(func() error { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://localhost:"+port+"/hello", nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("https://localhost:%d/hello", port), nil) Expect(err).ToNot(HaveOccurred()) resp, err := client.Do(req) Expect(err).ToNot(HaveOccurred()) @@ -172,7 +184,7 @@ var _ = Describe("HTTP tests", func() { close(handlerCalled) }) - req, err := http.NewRequest(http.MethodGet, "https://localhost:"+port+"/headers/request", nil) + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://localhost:%d/headers/request", port), nil) Expect(err).ToNot(HaveOccurred()) req.Header.Set("foo", "bar") req.Header.Set("lorem", "ipsum") @@ -189,7 +201,7 @@ var _ = Describe("HTTP tests", func() { w.Header().Set("lorem", "ipsum") }) - resp, err := client.Get("https://localhost:" + port + "/headers/response") + resp, err := client.Get(fmt.Sprintf("https://localhost:%d/headers/response", port)) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) Expect(resp.Header.Get("foo")).To(Equal("bar")) @@ -197,7 +209,7 @@ var _ = Describe("HTTP tests", func() { }) It("downloads a small file", func() { - resp, err := client.Get("https://localhost:" + port + "/prdata") + resp, err := client.Get(fmt.Sprintf("https://localhost:%d/prdata", port)) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 5*time.Second)) @@ -206,7 +218,7 @@ var _ = Describe("HTTP tests", func() { }) It("downloads a large file", func() { - resp, err := client.Get("https://localhost:" + port + "/prdatalong") + resp, err := client.Get(fmt.Sprintf("https://localhost:%d/prdatalong", port)) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 20*time.Second)) @@ -218,7 +230,7 @@ var _ = Describe("HTTP tests", func() { const num = 150 for i := 0; i < num; i++ { - resp, err := client.Get("https://localhost:" + port + "/hello") + resp, err := client.Get(fmt.Sprintf("https://localhost:%d/hello", port)) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second)) @@ -231,7 +243,7 @@ var _ = Describe("HTTP tests", func() { const num = 150 for i := 0; i < num; i++ { - resp, err := client.Get("https://localhost:" + port + "/prdata") + resp, err := client.Get(fmt.Sprintf("https://localhost:%d/prdata", port)) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) Expect(resp.Body.Close()).To(Succeed()) @@ -240,7 +252,7 @@ var _ = Describe("HTTP tests", func() { It("posts a small message", func() { resp, err := client.Post( - "https://localhost:"+port+"/echo", + fmt.Sprintf("https://localhost:%d/echo", port), "text/plain", bytes.NewReader([]byte("Hello, world!")), ) @@ -253,7 +265,7 @@ var _ = Describe("HTTP tests", func() { It("uploads a file", func() { resp, err := client.Post( - "https://localhost:"+port+"/echo", + fmt.Sprintf("https://localhost:%d/echo", port), "text/plain", bytes.NewReader(PRData), ) @@ -277,7 +289,7 @@ var _ = Describe("HTTP tests", func() { }) client.Transport.(*http3.RoundTripper).DisableCompression = false - resp, err := client.Get("https://localhost:" + port + "/gzipped/hello") + resp, err := client.Get(fmt.Sprintf("https://localhost:%d/gzipped/hello", port)) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) Expect(resp.Uncompressed).To(BeTrue()) @@ -303,7 +315,7 @@ var _ = Describe("HTTP tests", func() { } }) - req, err := http.NewRequest(http.MethodGet, "https://localhost:"+port+"/cancel", nil) + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://localhost:%d/cancel", port), nil) Expect(err).ToNot(HaveOccurred()) ctx, cancel := context.WithCancel(context.Background()) req = req.WithContext(ctx) @@ -336,7 +348,7 @@ var _ = Describe("HTTP tests", func() { }) r, w := io.Pipe() - req, err := http.NewRequest("PUT", "https://localhost:"+port+"/echoline", r) + req, err := http.NewRequest(http.MethodPut, fmt.Sprintf("https://localhost:%d/echoline", port), r) Expect(err).ToNot(HaveOccurred()) rsp, err := client.Do(req) Expect(err).ToNot(HaveOccurred()) @@ -373,7 +385,7 @@ var _ = Describe("HTTP tests", func() { }() }) - req, err := http.NewRequest(http.MethodGet, "https://localhost:"+port+"/httpstreamer", nil) + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://localhost:%d/httpstreamer", port), nil) Expect(err).ToNot(HaveOccurred()) rsp, err := client.Transport.(*http3.RoundTripper).RoundTripOpt(req, http3.RoundTripOpt{DontCloseRequestStream: true}) Expect(err).ToNot(HaveOccurred()) @@ -431,7 +443,11 @@ var _ = Describe("HTTP tests", func() { }) expectedEnd := time.Now().Add(deadlineDelay) - resp, err := client.Post("https://localhost:"+port+"/read-deadline", "text/plain", neverEnding('a')) + resp, err := client.Post( + fmt.Sprintf("https://localhost:%d/read-deadline", port), + "text/plain", + neverEnding('a'), + ) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) @@ -453,7 +469,7 @@ var _ = Describe("HTTP tests", func() { expectedEnd := time.Now().Add(deadlineDelay) - resp, err := client.Get("https://localhost:" + port + "/write-deadline") + resp, err := client.Get(fmt.Sprintf("https://localhost:%d/write-deadline", port)) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) diff --git a/integrationtests/self/multiplex_test.go b/integrationtests/self/multiplex_test.go index 92a23025..b508f9aa 100644 --- a/integrationtests/self/multiplex_test.go +++ b/integrationtests/self/multiplex_test.go @@ -2,9 +2,11 @@ package self_test import ( "context" + "crypto/rand" "io" "net" "runtime" + "sync/atomic" "time" . "github.com/onsi/ginkgo/v2" @@ -209,4 +211,67 @@ var _ = Describe("Multiplexing", func() { }) } }) + + It("sends and receives non-QUIC packets", func() { + addr1, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + conn1, err := net.ListenUDP("udp", addr1) + Expect(err).ToNot(HaveOccurred()) + defer conn1.Close() + tr1 := &quic.Transport{Conn: conn1} + + addr2, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + conn2, err := net.ListenUDP("udp", addr2) + Expect(err).ToNot(HaveOccurred()) + defer conn2.Close() + tr2 := &quic.Transport{Conn: conn2} + + server, err := tr1.Listen(getTLSConfig(), getQuicConfig(nil)) + Expect(err).ToNot(HaveOccurred()) + runServer(server) + defer server.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + var sentPackets, rcvdPackets atomic.Int64 + const packetLen = 128 + // send a non-QUIC packet every 100µs + go func() { + defer GinkgoRecover() + ticker := time.NewTicker(time.Millisecond / 10) + defer ticker.Stop() + for { + select { + case <-ticker.C: + case <-ctx.Done(): + return + } + b := make([]byte, packetLen) + rand.Read(b[1:]) // keep the first byte set to 0, so it's not classified as a QUIC packet + _, err := tr1.WriteTo(b, tr2.Conn.LocalAddr()) + Expect(err).ToNot(HaveOccurred()) + sentPackets.Add(1) + } + }() + + // receive and count non-QUIC packets + go func() { + defer GinkgoRecover() + for { + b := make([]byte, 1024) + n, addr, err := tr2.ReadNonQUICPacket(ctx, b) + if err != nil { + Expect(err).To(MatchError(context.Canceled)) + return + } + Expect(addr).To(Equal(tr1.Conn.LocalAddr())) + Expect(n).To(Equal(packetLen)) + rcvdPackets.Add(1) + } + }() + dial(tr2, server.Addr()) + Eventually(func() int64 { return sentPackets.Load() }).Should(BeNumerically(">", 10)) + Eventually(func() int64 { return rcvdPackets.Load() }).Should(BeNumerically(">=", sentPackets.Load()*4/5)) + }) }) diff --git a/integrationtests/self/resumption_test.go b/integrationtests/self/resumption_test.go index 3e827cb8..d38ae9a7 100644 --- a/integrationtests/self/resumption_test.go +++ b/integrationtests/self/resumption_test.go @@ -56,7 +56,7 @@ var _ = Describe("TLS session resumption", func() { context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), tlsConf, - nil, + getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) var sessionKey string @@ -71,7 +71,7 @@ var _ = Describe("TLS session resumption", func() { context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), tlsConf, - nil, + getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) Expect(gets).To(Receive(Equal(sessionKey))) @@ -85,7 +85,7 @@ var _ = Describe("TLS session resumption", func() { It("doesn't use session resumption, if the config disables it", func() { sConf := getTLSConfig() sConf.SessionTicketsDisabled = true - server, err := quic.ListenAddr("localhost:0", sConf, nil) + server, err := quic.ListenAddr("localhost:0", sConf, getQuicConfig(nil)) Expect(err).ToNot(HaveOccurred()) defer server.Close() @@ -98,7 +98,7 @@ var _ = Describe("TLS session resumption", func() { context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), tlsConf, - nil, + getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) Consistently(puts).ShouldNot(Receive()) @@ -114,7 +114,55 @@ var _ = Describe("TLS session resumption", func() { context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), tlsConf, - nil, + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse()) + + serverConn, err = server.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse()) + }) + + It("doesn't use session resumption, if the config returned by GetConfigForClient disables it", func() { + sConf := &tls.Config{ + GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { + conf := getTLSConfig() + conf.SessionTicketsDisabled = true + return conf, nil + }, + } + + server, err := quic.ListenAddr("localhost:0", sConf, getQuicConfig(nil)) + Expect(err).ToNot(HaveOccurred()) + defer server.Close() + + gets := make(chan string, 100) + puts := make(chan string, 100) + cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts) + tlsConf := getTLSClientConfig() + tlsConf.ClientSessionCache = cache + conn, err := quic.DialAddr( + context.Background(), + fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), + tlsConf, + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + Consistently(puts).ShouldNot(Receive()) + Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse()) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + serverConn, err := server.Accept(ctx) + Expect(err).ToNot(HaveOccurred()) + Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse()) + + conn, err = quic.DialAddr( + context.Background(), + fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), + tlsConf, + getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse()) diff --git a/integrationtests/self/zero_rtt_oldgo_test.go b/integrationtests/self/zero_rtt_oldgo_test.go index 919e9947..101c1ffd 100644 --- a/integrationtests/self/zero_rtt_oldgo_test.go +++ b/integrationtests/self/zero_rtt_oldgo_test.go @@ -802,4 +802,116 @@ var _ = Describe("0-RTT", func() { Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10)) Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0))) }) + + It("sends 0-RTT datagrams", func() { + tlsConf, clientTLSConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{ + EnableDatagrams: true, + })) + + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: true, + EnableDatagrams: true, + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + // second connection + sentMessage := GeneratePRData(100) + var receivedMessage []byte + received := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(received) + conn, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + receivedMessage, err = conn.ReceiveMessage(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) + }() + conn, err := quic.DialAddrEarly( + context.Background(), + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientTLSConf, + getQuicConfig(&quic.Config{ + EnableDatagrams: true, + }), + ) + Expect(err).ToNot(HaveOccurred()) + Expect(conn.ConnectionState().SupportsDatagrams).To(BeTrue()) + Expect(conn.SendMessage(sentMessage)).To(Succeed()) + <-conn.HandshakeComplete() + <-received + + Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) + Expect(receivedMessage).To(Equal(sentMessage)) + num0RTT := atomic.LoadUint32(num0RTTPackets) + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) + Expect(num0RTT).ToNot(BeZero()) + zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + Expect(zeroRTTPackets).To(HaveLen(1)) + Expect(conn.CloseWithError(0, "")).To(Succeed()) + }) + + It("rejects 0-RTT datagrams when the server doesn't support datagrams anymore", func() { + tlsConf, clientTLSConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{ + EnableDatagrams: true, + })) + + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: true, + EnableDatagrams: false, + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + // second connection + go func() { + defer GinkgoRecover() + conn, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + _, err = conn.ReceiveMessage(context.Background()) + Expect(err.Error()).To(Equal("datagram support disabled")) + <-conn.HandshakeComplete() + Expect(conn.ConnectionState().Used0RTT).To(BeFalse()) + }() + conn, err := quic.DialAddrEarly( + context.Background(), + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientTLSConf, + getQuicConfig(&quic.Config{ + EnableDatagrams: true, + }), + ) + Expect(err).ToNot(HaveOccurred()) + // the client can temporarily send datagrams but the server doesn't process them. + Expect(conn.ConnectionState().SupportsDatagrams).To(BeTrue()) + Expect(conn.SendMessage(make([]byte, 100))).To(Succeed()) + <-conn.HandshakeComplete() + + Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse()) + Expect(conn.ConnectionState().Used0RTT).To(BeFalse()) + num0RTT := atomic.LoadUint32(num0RTTPackets) + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) + Expect(num0RTT).ToNot(BeZero()) + Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) + }) }) diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index 5021e43f..7fe81930 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -939,4 +939,119 @@ var _ = Describe("0-RTT", func() { ) Expect(restored).To(BeTrue()) }) + + It("sends 0-RTT datagrams", func() { + tlsConf := getTLSConfig() + clientTLSConf := getTLSClientConfig() + dialAndReceiveSessionTicket(tlsConf, getQuicConfig(&quic.Config{ + EnableDatagrams: true, + }), clientTLSConf) + + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: true, + EnableDatagrams: true, + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + // second connection + sentMessage := GeneratePRData(100) + var receivedMessage []byte + received := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(received) + conn, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + receivedMessage, err = conn.ReceiveMessage(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) + }() + conn, err := quic.DialAddrEarly( + context.Background(), + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientTLSConf, + getQuicConfig(&quic.Config{ + EnableDatagrams: true, + }), + ) + Expect(err).ToNot(HaveOccurred()) + Expect(conn.ConnectionState().SupportsDatagrams).To(BeTrue()) + Expect(conn.SendMessage(sentMessage)).To(Succeed()) + <-conn.HandshakeComplete() + <-received + + Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) + Expect(receivedMessage).To(Equal(sentMessage)) + num0RTT := atomic.LoadUint32(num0RTTPackets) + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) + Expect(num0RTT).ToNot(BeZero()) + zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + Expect(zeroRTTPackets).To(HaveLen(1)) + Expect(conn.CloseWithError(0, "")).To(Succeed()) + }) + + It("rejects 0-RTT datagrams when the server doesn't support datagrams anymore", func() { + tlsConf := getTLSConfig() + clientTLSConf := getTLSClientConfig() + dialAndReceiveSessionTicket(tlsConf, getQuicConfig(&quic.Config{ + EnableDatagrams: true, + }), clientTLSConf) + + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: true, + EnableDatagrams: false, + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + // second connection + go func() { + defer GinkgoRecover() + conn, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + _, err = conn.ReceiveMessage(context.Background()) + Expect(err.Error()).To(Equal("datagram support disabled")) + <-conn.HandshakeComplete() + Expect(conn.ConnectionState().Used0RTT).To(BeFalse()) + }() + conn, err := quic.DialAddrEarly( + context.Background(), + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientTLSConf, + getQuicConfig(&quic.Config{ + EnableDatagrams: true, + }), + ) + Expect(err).ToNot(HaveOccurred()) + // the client can temporarily send datagrams but the server doesn't process them. + Expect(conn.ConnectionState().SupportsDatagrams).To(BeTrue()) + Expect(conn.SendMessage(make([]byte, 100))).To(Succeed()) + <-conn.HandshakeComplete() + + Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse()) + Expect(conn.ConnectionState().Used0RTT).To(BeFalse()) + num0RTT := atomic.LoadUint32(num0RTTPackets) + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) + Expect(num0RTT).ToNot(BeZero()) + Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) + }) }) diff --git a/integrationtests/tools/proxy/proxy_test.go b/integrationtests/tools/proxy/proxy_test.go index b89caabd..bef42f28 100644 --- a/integrationtests/tools/proxy/proxy_test.go +++ b/integrationtests/tools/proxy/proxy_test.go @@ -4,6 +4,7 @@ import ( "bytes" "fmt" "net" + "runtime" "runtime/pprof" "strconv" "strings" @@ -68,7 +69,11 @@ var _ = Describe("QUIC Proxy", func() { addr, err := net.ResolveUDPAddr("udp", "localhost:"+strconv.Itoa(proxy.LocalPort())) Expect(err).ToNot(HaveOccurred()) _, err = net.ListenUDP("udp", addr) - Expect(err).To(MatchError(fmt.Sprintf("listen udp 127.0.0.1:%d: bind: address already in use", proxy.LocalPort()))) + if runtime.GOOS == "windows" { + Expect(err).To(MatchError(fmt.Sprintf("listen udp 127.0.0.1:%d: bind: Only one usage of each socket address (protocol/network address/port) is normally permitted.", proxy.LocalPort()))) + } else { + Expect(err).To(MatchError(fmt.Sprintf("listen udp 127.0.0.1:%d: bind: address already in use", proxy.LocalPort()))) + } Expect(proxy.Close()).To(Succeed()) // stopping is tested in the next test }) diff --git a/internal/ackhandler/ackhandler_suite_test.go b/internal/ackhandler/ackhandler_suite_test.go index 069aaa79..a0cf3ee1 100644 --- a/internal/ackhandler/ackhandler_suite_test.go +++ b/internal/ackhandler/ackhandler_suite_test.go @@ -3,9 +3,9 @@ package ackhandler import ( "testing" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) func TestCrypto(t *testing.T) { diff --git a/internal/ackhandler/mock_sent_packet_tracker_test.go b/internal/ackhandler/mock_sent_packet_tracker_test.go index fecb3272..cd34cfdb 100644 --- a/internal/ackhandler/mock_sent_packet_tracker_test.go +++ b/internal/ackhandler/mock_sent_packet_tracker_test.go @@ -7,7 +7,7 @@ package ackhandler import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/internal/ackhandler/mockgen.go b/internal/ackhandler/mockgen.go index b432a0a0..3d2f1082 100644 --- a/internal/ackhandler/mockgen.go +++ b/internal/ackhandler/mockgen.go @@ -2,5 +2,5 @@ package ackhandler -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package ackhandler -destination mock_sent_packet_tracker_test.go github.com/refraction-networking/uquic/internal/ackhandler SentPacketTracker" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package ackhandler -destination mock_sent_packet_tracker_test.go github.com/refraction-networking/uquic/internal/ackhandler SentPacketTracker" type SentPacketTracker = sentPacketTracker diff --git a/internal/ackhandler/received_packet_handler_test.go b/internal/ackhandler/received_packet_handler_test.go index d0377200..a537399b 100644 --- a/internal/ackhandler/received_packet_handler_test.go +++ b/internal/ackhandler/received_packet_handler_test.go @@ -3,14 +3,13 @@ package ackhandler import ( "time" - "github.com/golang/mock/gomock" - "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/utils" "github.com/refraction-networking/uquic/internal/wire" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) var _ = Describe("Received Packet Handler", func() { diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index 9b21d1fe..15a09fa2 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -4,8 +4,6 @@ import ( "fmt" "time" - "github.com/golang/mock/gomock" - "github.com/refraction-networking/uquic/internal/mocks" "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/qerr" @@ -14,6 +12,7 @@ import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) type customFrameHandler struct { diff --git a/internal/flowcontrol/flowcontrol_suite_test.go b/internal/flowcontrol/flowcontrol_suite_test.go index 8831296d..6cfe981d 100644 --- a/internal/flowcontrol/flowcontrol_suite_test.go +++ b/internal/flowcontrol/flowcontrol_suite_test.go @@ -3,9 +3,9 @@ package flowcontrol import ( "testing" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) func TestFlowControl(t *testing.T) { diff --git a/internal/handshake/cipher_suite.go b/internal/handshake/cipher_suite.go index 2ed922a8..368a95d4 100644 --- a/internal/handshake/cipher_suite.go +++ b/internal/handshake/cipher_suite.go @@ -31,7 +31,7 @@ func getCipherSuite(id uint16) *cipherSuite { case tls.TLS_CHACHA20_POLY1305_SHA256: return &cipherSuite{ID: tls.TLS_CHACHA20_POLY1305_SHA256, Hash: crypto.SHA256, KeyLen: 32, AEAD: aeadChaCha20Poly1305} case tls.TLS_AES_256_GCM_SHA384: - return &cipherSuite{ID: tls.TLS_AES_256_GCM_SHA384, Hash: crypto.SHA256, KeyLen: 32, AEAD: aeadAESGCMTLS13} + return &cipherSuite{ID: tls.TLS_AES_256_GCM_SHA384, Hash: crypto.SHA384, KeyLen: 32, AEAD: aeadAESGCMTLS13} default: panic(fmt.Sprintf("unknown cypher suite: %d", id)) } diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 5630b880..f2e6b67d 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "net" + "strings" "sync" "sync/atomic" "time" @@ -358,10 +359,15 @@ func (h *cryptoSetup) getDataForSessionTicket() []byte { // Due to limitations in crypto/tls, it's only possible to generate a single session ticket per connection. // It is only valid for the server. func (h *cryptoSetup) GetSessionTicket() ([]byte, error) { - if h.tlsConf.SessionTicketsDisabled { - return nil, nil - } - if err := h.conn.SendSessionTicket(tls.QUICSessionTicketOptions{h.allow0RTT}); err != nil { + if err := qtls.SendSessionTicket(h.conn, h.allow0RTT); err != nil { + // Session tickets might be disabled by tls.Config.SessionTicketsDisabled. + // We can't check h.tlsConfig here, since the actual config might have been obtained from + // the GetConfigForClient callback. + // See https://github.com/golang/go/issues/62032. + // Once that issue is resolved, this error assertion can be removed. + if strings.Contains(err.Error(), "session ticket keys unavailable") { + return nil, nil + } return nil, err } ev := h.conn.NextEvent() @@ -660,8 +666,9 @@ func (h *cryptoSetup) ConnectionState() ConnectionState { } func wrapError(err error) error { + // alert 80 is an internal error if alertErr := qtls.AlertError(0); errors.As(err, &alertErr) && alertErr != 80 { - return qerr.NewLocalCryptoError(uint8(alertErr), err.Error()) + return qerr.NewLocalCryptoError(uint8(alertErr), err) } return &qerr.TransportError{ErrorCode: qerr.InternalError, ErrorMessage: err.Error()} } diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index f29517cd..fbc82512 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -18,10 +18,9 @@ import ( "github.com/refraction-networking/uquic/internal/utils" "github.com/refraction-networking/uquic/internal/wire" - "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) const ( diff --git a/internal/handshake/handshake_suite_test.go b/internal/handshake/handshake_suite_test.go index 45ff1c3d..7d82421c 100644 --- a/internal/handshake/handshake_suite_test.go +++ b/internal/handshake/handshake_suite_test.go @@ -7,10 +7,9 @@ import ( tls "github.com/refraction-networking/utls" - "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) func TestHandshake(t *testing.T) { diff --git a/internal/handshake/session_ticket_test.go b/internal/handshake/session_ticket_test.go index a113f024..83c64a7d 100644 --- a/internal/handshake/session_ticket_test.go +++ b/internal/handshake/session_ticket_test.go @@ -17,6 +17,7 @@ var _ = Describe("Session Ticket", func() { InitialMaxStreamDataBidiLocal: 1, InitialMaxStreamDataBidiRemote: 2, ActiveConnectionIDLimit: 10, + MaxDatagramFrameSize: 20, }, RTT: 1337 * time.Microsecond, } @@ -25,6 +26,7 @@ var _ = Describe("Session Ticket", func() { Expect(t.Parameters.InitialMaxStreamDataBidiLocal).To(BeEquivalentTo(1)) Expect(t.Parameters.InitialMaxStreamDataBidiRemote).To(BeEquivalentTo(2)) Expect(t.Parameters.ActiveConnectionIDLimit).To(BeEquivalentTo(10)) + Expect(t.Parameters.MaxDatagramFrameSize).To(BeEquivalentTo(20)) Expect(t.RTT).To(Equal(1337 * time.Microsecond)) }) diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index 4072d6dd..4bc1fa86 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -13,9 +13,9 @@ import ( "github.com/refraction-networking/uquic/internal/qerr" "github.com/refraction-networking/uquic/internal/utils" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) var _ = Describe("Updatable AEAD", func() { diff --git a/internal/mocks/ackhandler/received_packet_handler.go b/internal/mocks/ackhandler/received_packet_handler.go index a368b364..569b690e 100644 --- a/internal/mocks/ackhandler/received_packet_handler.go +++ b/internal/mocks/ackhandler/received_packet_handler.go @@ -8,9 +8,9 @@ import ( reflect "reflect" time "time" - gomock "github.com/golang/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" wire "github.com/refraction-networking/uquic/internal/wire" + gomock "go.uber.org/mock/gomock" ) // MockReceivedPacketHandler is a mock of ReceivedPacketHandler interface. diff --git a/internal/mocks/ackhandler/sent_packet_handler.go b/internal/mocks/ackhandler/sent_packet_handler.go index 76db968c..37fae0e6 100644 --- a/internal/mocks/ackhandler/sent_packet_handler.go +++ b/internal/mocks/ackhandler/sent_packet_handler.go @@ -8,7 +8,7 @@ import ( reflect "reflect" time "time" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" ackhandler "github.com/refraction-networking/uquic/internal/ackhandler" protocol "github.com/refraction-networking/uquic/internal/protocol" wire "github.com/refraction-networking/uquic/internal/wire" diff --git a/internal/mocks/congestion.go b/internal/mocks/congestion.go index c81b7177..a2ba0aab 100644 --- a/internal/mocks/congestion.go +++ b/internal/mocks/congestion.go @@ -8,7 +8,7 @@ import ( reflect "reflect" time "time" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/internal/mocks/connection_flow_controller.go b/internal/mocks/connection_flow_controller.go index 02d3c3a1..60d87d39 100644 --- a/internal/mocks/connection_flow_controller.go +++ b/internal/mocks/connection_flow_controller.go @@ -7,7 +7,7 @@ package mocks import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/internal/mocks/crypto_setup.go b/internal/mocks/crypto_setup.go index 61a019a6..5558f973 100644 --- a/internal/mocks/crypto_setup.go +++ b/internal/mocks/crypto_setup.go @@ -7,7 +7,7 @@ package mocks import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" handshake "github.com/refraction-networking/uquic/internal/handshake" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/internal/mocks/logging/connection_tracer.go b/internal/mocks/logging/connection_tracer.go index e03a5897..4a9d1fd7 100644 --- a/internal/mocks/logging/connection_tracer.go +++ b/internal/mocks/logging/connection_tracer.go @@ -9,11 +9,12 @@ import ( reflect "reflect" time "time" - gomock "github.com/golang/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" utils "github.com/refraction-networking/uquic/internal/utils" wire "github.com/refraction-networking/uquic/internal/wire" logging "github.com/refraction-networking/uquic/logging" + + gomock "go.uber.org/mock/gomock" ) // MockConnectionTracer is a mock of ConnectionTracer interface. diff --git a/internal/mocks/logging/tracer.go b/internal/mocks/logging/tracer.go index 42f81660..b479c66c 100644 --- a/internal/mocks/logging/tracer.go +++ b/internal/mocks/logging/tracer.go @@ -8,10 +8,10 @@ import ( net "net" reflect "reflect" - gomock "github.com/golang/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" wire "github.com/refraction-networking/uquic/internal/wire" logging "github.com/refraction-networking/uquic/logging" + gomock "go.uber.org/mock/gomock" ) // MockTracer is a mock of Tracer interface. diff --git a/internal/mocks/long_header_opener.go b/internal/mocks/long_header_opener.go index 61d55f27..9552aff7 100644 --- a/internal/mocks/long_header_opener.go +++ b/internal/mocks/long_header_opener.go @@ -7,7 +7,7 @@ package mocks import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/internal/mocks/mockgen.go b/internal/mocks/mockgen.go index 6f68d26c..58ed5578 100644 --- a/internal/mocks/mockgen.go +++ b/internal/mocks/mockgen.go @@ -1,19 +1,19 @@ package mocks -//go:generate sh -c "go run github.com/golang/mock/mockgen -package mockquic -destination quic/stream.go github.com/refraction-networking/uquic Stream" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package mockquic -destination quic/early_conn_tmp.go github.com/refraction-networking/uquic EarlyConnection && sed 's/qtls.ConnectionState/quic.ConnectionState/g' quic/early_conn_tmp.go > quic/early_conn.go && rm quic/early_conn_tmp.go && go run golang.org/x/tools/cmd/goimports -w quic/early_conn.go" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package mocklogging -destination logging/tracer.go github.com/refraction-networking/uquic/logging Tracer" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package mocklogging -destination logging/connection_tracer.go github.com/refraction-networking/uquic/logging ConnectionTracer" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package mocks -destination short_header_sealer.go github.com/refraction-networking/uquic/internal/handshake ShortHeaderSealer" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package mocks -destination short_header_opener.go github.com/refraction-networking/uquic/internal/handshake ShortHeaderOpener" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package mocks -destination long_header_opener.go github.com/refraction-networking/uquic/internal/handshake LongHeaderOpener" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package mocks -destination crypto_setup_tmp.go github.com/refraction-networking/uquic/internal/handshake CryptoSetup && sed -E 's~github.com/quic-go/qtls[[:alnum:]_-]*~github.com/refraction-networking/uquic/internal/qtls~g; s~qtls.ConnectionStateWith0RTT~qtls.ConnectionState~g' crypto_setup_tmp.go > crypto_setup.go && rm crypto_setup_tmp.go && go run golang.org/x/tools/cmd/goimports -w crypto_setup.go" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package mocks -destination stream_flow_controller.go github.com/refraction-networking/uquic/internal/flowcontrol StreamFlowController" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package mocks -destination congestion.go github.com/refraction-networking/uquic/internal/congestion SendAlgorithmWithDebugInfos" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package mocks -destination connection_flow_controller.go github.com/refraction-networking/uquic/internal/flowcontrol ConnectionFlowController" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package mockackhandler -destination ackhandler/sent_packet_handler.go github.com/refraction-networking/uquic/internal/ackhandler SentPacketHandler" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package mockackhandler -destination ackhandler/received_packet_handler.go github.com/refraction-networking/uquic/internal/ackhandler ReceivedPacketHandler" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package mockquic -destination quic/stream.go github.com/refraction-networking/uquic Stream" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package mockquic -destination quic/early_conn_tmp.go github.com/refraction-networking/uquic EarlyConnection && sed 's/qtls.ConnectionState/quic.ConnectionState/g' quic/early_conn_tmp.go > quic/early_conn.go && rm quic/early_conn_tmp.go && go run golang.org/x/tools/cmd/goimports -w quic/early_conn.go" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocklogging -destination logging/tracer.go github.com/refraction-networking/uquic/logging Tracer" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocklogging -destination logging/connection_tracer.go github.com/refraction-networking/uquic/logging ConnectionTracer" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination short_header_sealer.go github.com/refraction-networking/uquic/internal/handshake ShortHeaderSealer" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination short_header_opener.go github.com/refraction-networking/uquic/internal/handshake ShortHeaderOpener" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination long_header_opener.go github.com/refraction-networking/uquic/internal/handshake LongHeaderOpener" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination crypto_setup_tmp.go github.com/refraction-networking/uquic/internal/handshake CryptoSetup && sed -E 's~github.com/quic-go/qtls[[:alnum:]_-]*~github.com/refraction-networking/uquic/internal/qtls~g; s~qtls.ConnectionStateWith0RTT~qtls.ConnectionState~g' crypto_setup_tmp.go > crypto_setup.go && rm crypto_setup_tmp.go && go run golang.org/x/tools/cmd/goimports -w crypto_setup.go" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination stream_flow_controller.go github.com/refraction-networking/uquic/internal/flowcontrol StreamFlowController" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination congestion.go github.com/refraction-networking/uquic/internal/congestion SendAlgorithmWithDebugInfos" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination connection_flow_controller.go github.com/refraction-networking/uquic/internal/flowcontrol ConnectionFlowController" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package mockackhandler -destination ackhandler/sent_packet_handler.go github.com/refraction-networking/uquic/internal/ackhandler SentPacketHandler" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package mockackhandler -destination ackhandler/received_packet_handler.go github.com/refraction-networking/uquic/internal/ackhandler ReceivedPacketHandler" // The following command produces a warning message on OSX, however, it still generates the correct mock file. // See https://github.com/golang/mock/issues/339 for details. -//go:generate sh -c "go run github.com/golang/mock/mockgen -package mocktls -destination tls/client_session_cache.go crypto/tls ClientSessionCache" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocktls -destination tls/client_session_cache.go crypto/tls ClientSessionCache" diff --git a/internal/mocks/quic/early_conn.go b/internal/mocks/quic/early_conn.go index 816ca5a3..428c0eb1 100644 --- a/internal/mocks/quic/early_conn.go +++ b/internal/mocks/quic/early_conn.go @@ -9,9 +9,9 @@ import ( net "net" reflect "reflect" - gomock "github.com/golang/mock/gomock" quic "github.com/refraction-networking/uquic" qerr "github.com/refraction-networking/uquic/internal/qerr" + gomock "go.uber.org/mock/gomock" ) // MockEarlyConnection is a mock of EarlyConnection interface. diff --git a/internal/mocks/quic/stream.go b/internal/mocks/quic/stream.go index c548883b..97c52c52 100644 --- a/internal/mocks/quic/stream.go +++ b/internal/mocks/quic/stream.go @@ -9,9 +9,9 @@ import ( reflect "reflect" time "time" - gomock "github.com/golang/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" qerr "github.com/refraction-networking/uquic/internal/qerr" + gomock "go.uber.org/mock/gomock" ) // MockStream is a mock of Stream interface. diff --git a/internal/mocks/short_header_opener.go b/internal/mocks/short_header_opener.go index c4ab76de..4f609b93 100644 --- a/internal/mocks/short_header_opener.go +++ b/internal/mocks/short_header_opener.go @@ -8,7 +8,7 @@ import ( reflect "reflect" time "time" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/internal/mocks/short_header_sealer.go b/internal/mocks/short_header_sealer.go index 90f4b80f..dfb654bf 100644 --- a/internal/mocks/short_header_sealer.go +++ b/internal/mocks/short_header_sealer.go @@ -7,7 +7,7 @@ package mocks import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/internal/mocks/stream_flow_controller.go b/internal/mocks/stream_flow_controller.go index ab6dcade..3a16e937 100644 --- a/internal/mocks/stream_flow_controller.go +++ b/internal/mocks/stream_flow_controller.go @@ -7,7 +7,7 @@ package mocks import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/internal/mocks/tls/client_session_cache.go b/internal/mocks/tls/client_session_cache.go index 5178c62a..9a7c3444 100644 --- a/internal/mocks/tls/client_session_cache.go +++ b/internal/mocks/tls/client_session_cache.go @@ -8,7 +8,7 @@ import ( tls "github.com/refraction-networking/utls" reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" ) // MockClientSessionCache is a mock of ClientSessionCache interface. diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go index 98bc9ffb..b8104882 100644 --- a/internal/protocol/protocol.go +++ b/internal/protocol/protocol.go @@ -43,6 +43,21 @@ const ( ECNCE // 11 ) +func (e ECN) String() string { + switch e { + case ECNNon: + return "Not-ECT" + case ECT1: + return "ECT(1)" + case ECT0: + return "ECT(0)" + case ECNCE: + return "CE" + default: + return fmt.Sprintf("invalid ECN value: %d", e) + } +} + // A ByteCount in QUIC type ByteCount int64 diff --git a/internal/protocol/protocol_test.go b/internal/protocol/protocol_test.go index d9048f01..e672d31e 100644 --- a/internal/protocol/protocol_test.go +++ b/internal/protocol/protocol_test.go @@ -22,4 +22,12 @@ var _ = Describe("Protocol", func() { Expect(ECN(0b00000001)).To(Equal(ECT1)) Expect(ECN(0b00000011)).To(Equal(ECNCE)) }) + + It("has a string representation for ECN", func() { + Expect(ECNNon.String()).To(Equal("Not-ECT")) + Expect(ECT0.String()).To(Equal("ECT(0)")) + Expect(ECT1.String()).To(Equal("ECT(1)")) + Expect(ECNCE.String()).To(Equal("CE")) + Expect(ECN(42).String()).To(Equal("invalid ECN value: 42")) + }) }) diff --git a/internal/qerr/errors.go b/internal/qerr/errors.go index 4a053ae0..c3d8465b 100644 --- a/internal/qerr/errors.go +++ b/internal/qerr/errors.go @@ -17,15 +17,16 @@ type TransportError struct { FrameType uint64 ErrorCode TransportErrorCode ErrorMessage string + error error // only set for local errors, sometimes } var _ error = &TransportError{} // NewLocalCryptoError create a new TransportError instance for a crypto error -func NewLocalCryptoError(tlsAlert uint8, errorMessage string) *TransportError { +func NewLocalCryptoError(tlsAlert uint8, err error) *TransportError { return &TransportError{ - ErrorCode: 0x100 + TransportErrorCode(tlsAlert), - ErrorMessage: errorMessage, + ErrorCode: 0x100 + TransportErrorCode(tlsAlert), + error: err, } } @@ -35,6 +36,9 @@ func (e *TransportError) Error() string { str += fmt.Sprintf(" (frame type: %#x)", e.FrameType) } msg := e.ErrorMessage + if len(msg) == 0 && e.error != nil { + msg = e.error.Error() + } if len(msg) == 0 { msg = e.ErrorCode.Message() } @@ -48,6 +52,10 @@ func (e *TransportError) Is(target error) bool { return target == net.ErrClosed } +func (e *TransportError) Unwrap() error { + return e.error +} + // An ApplicationErrorCode is an application-defined error code. type ApplicationErrorCode uint64 diff --git a/internal/qerr/errors_test.go b/internal/qerr/errors_test.go index cfad6156..7bfbc734 100644 --- a/internal/qerr/errors_test.go +++ b/internal/qerr/errors_test.go @@ -2,6 +2,7 @@ package qerr import ( "errors" + "fmt" "net" "github.com/refraction-networking/uquic/internal/protocol" @@ -10,6 +11,12 @@ import ( . "github.com/onsi/gomega" ) +type myError int + +var _ error = myError(0) + +func (e myError) Error() string { return fmt.Sprintf("my error %d", e) } + var _ = Describe("QUIC Errors", func() { Context("Transport Errors", func() { It("has a string representation", func() { @@ -41,12 +48,20 @@ var _ = Describe("QUIC Errors", func() { Context("crypto errors", func() { It("has a string representation for errors with a message", func() { - err := NewLocalCryptoError(0x42, "foobar") - Expect(err.Error()).To(Equal("CRYPTO_ERROR 0x142 (local): foobar")) + myErr := myError(1337) + err := NewLocalCryptoError(0x42, myErr) + Expect(err.Error()).To(Equal("CRYPTO_ERROR 0x142 (local): my error 1337")) + }) + + It("unwraps errors", func() { + var myErr myError + err := NewLocalCryptoError(0x42, myError(1337)) + Expect(errors.As(err, &myErr)).To(BeTrue()) + Expect(myErr).To(BeEquivalentTo(1337)) }) It("has a string representation for errors without a message", func() { - err := NewLocalCryptoError(0x2a, "") + err := NewLocalCryptoError(0x2a, nil) Expect(err.Error()).To(Equal("CRYPTO_ERROR 0x12a (local): tls: bad certificate")) }) }) diff --git a/internal/qtls/qtls_suite_test.go b/internal/qtls/qtls_suite_test.go index e8ce652a..bde81e6c 100644 --- a/internal/qtls/qtls_suite_test.go +++ b/internal/qtls/qtls_suite_test.go @@ -3,10 +3,9 @@ package qtls import ( "testing" - gomock "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) func TestQTLS(t *testing.T) { diff --git a/internal/qtls/utls.go b/internal/qtls/utls.go index a056c741..72aa2086 100644 --- a/internal/qtls/utls.go +++ b/internal/qtls/utls.go @@ -10,13 +10,14 @@ import ( ) type ( - QUICConn = tls.QUICConn - UQUICConn = tls.UQUICConn // [UQUIC] - QUICConfig = tls.QUICConfig - QUICEvent = tls.QUICEvent - QUICEventKind = tls.QUICEventKind - QUICEncryptionLevel = tls.QUICEncryptionLevel - AlertError = tls.AlertError + QUICConn = tls.QUICConn + UQUICConn = tls.UQUICConn // [UQUIC] + QUICConfig = tls.QUICConfig + QUICEvent = tls.QUICEvent + QUICEventKind = tls.QUICEventKind + QUICEncryptionLevel = tls.QUICEncryptionLevel + QUICSessionTicketOptions = tls.QUICSessionTicketOptions + AlertError = tls.AlertError ) const ( @@ -166,3 +167,9 @@ func findExtraData(extras [][]byte) []byte { } return nil } + +func SendSessionTicket(c *QUICConn, allow0RTT bool) error { + return c.SendSessionTicket(tls.QUICSessionTicketOptions{ + EarlyData: allow0RTT, + }) +} diff --git a/internal/wire/header.go b/internal/wire/header.go index 4060d598..f6025019 100644 --- a/internal/wire/header.go +++ b/internal/wire/header.go @@ -74,6 +74,10 @@ func parseArbitraryLenConnectionIDs(r *bytes.Reader) (dest, src protocol.Arbitra return destConnID, srcConnID, nil } +func IsPotentialQUICPacket(firstByte byte) bool { + return firstByte&0x40 > 0 +} + // IsLongHeaderPacket says if this is a Long Header packet func IsLongHeaderPacket(firstByte byte) bool { return firstByte&0x80 > 0 diff --git a/internal/wire/transport_parameter_test.go b/internal/wire/transport_parameter_test.go index 48ae0923..faa6eaf3 100644 --- a/internal/wire/transport_parameter_test.go +++ b/internal/wire/transport_parameter_test.go @@ -503,6 +503,7 @@ var _ = Describe("Transport Parameters", func() { MaxBidiStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))), MaxUniStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))), ActiveConnectionIDLimit: 2 + getRandomValueUpTo(math.MaxInt64-2), + MaxDatagramFrameSize: protocol.ByteCount(getRandomValueUpTo(int64(protocol.MaxDatagramFrameSize))), } Expect(params.ValidFor0RTT(params)).To(BeTrue()) b := params.MarshalForSessionTicket(nil) @@ -515,6 +516,7 @@ var _ = Describe("Transport Parameters", func() { Expect(tp.MaxBidiStreamNum).To(Equal(params.MaxBidiStreamNum)) Expect(tp.MaxUniStreamNum).To(Equal(params.MaxUniStreamNum)) Expect(tp.ActiveConnectionIDLimit).To(Equal(params.ActiveConnectionIDLimit)) + Expect(tp.MaxDatagramFrameSize).To(Equal(params.MaxDatagramFrameSize)) }) It("rejects the parameters if it can't parse them", func() { @@ -540,6 +542,7 @@ var _ = Describe("Transport Parameters", func() { MaxBidiStreamNum: 5, MaxUniStreamNum: 6, ActiveConnectionIDLimit: 7, + MaxDatagramFrameSize: 1000, } BeforeEach(func() { @@ -611,6 +614,16 @@ var _ = Describe("Transport Parameters", func() { p.ActiveConnectionIDLimit = 0 Expect(p.ValidFor0RTT(saved)).To(BeFalse()) }) + + It("accepts the parameters if the MaxDatagramFrameSize was increased", func() { + p.MaxDatagramFrameSize = saved.MaxDatagramFrameSize + 1 + Expect(p.ValidFor0RTT(saved)).To(BeTrue()) + }) + + It("rejects the parameters if the MaxDatagramFrameSize reduced", func() { + p.MaxDatagramFrameSize = saved.MaxDatagramFrameSize - 1 + Expect(p.ValidFor0RTT(saved)).To(BeFalse()) + }) }) Context("client checks the parameters after successfully sending 0-RTT data", func() { @@ -623,6 +636,7 @@ var _ = Describe("Transport Parameters", func() { MaxBidiStreamNum: 5, MaxUniStreamNum: 6, ActiveConnectionIDLimit: 7, + MaxDatagramFrameSize: 1000, } BeforeEach(func() { @@ -699,6 +713,16 @@ var _ = Describe("Transport Parameters", func() { p.ActiveConnectionIDLimit = saved.ActiveConnectionIDLimit + 1 Expect(p.ValidForUpdate(saved)).To(BeTrue()) }) + + It("rejects the parameters if the MaxDatagramFrameSize reduced", func() { + p.MaxDatagramFrameSize = saved.MaxDatagramFrameSize - 1 + Expect(p.ValidForUpdate(saved)).To(BeFalse()) + }) + + It("doesn't reject the parameters if the MaxDatagramFrameSize increased", func() { + p.MaxDatagramFrameSize = saved.MaxDatagramFrameSize + 1 + Expect(p.ValidForUpdate(saved)).To(BeTrue()) + }) }) }) }) diff --git a/internal/wire/transport_parameters.go b/internal/wire/transport_parameters.go index 7a24e2ce..116f6049 100644 --- a/internal/wire/transport_parameters.go +++ b/internal/wire/transport_parameters.go @@ -464,6 +464,10 @@ func (p *TransportParameters) MarshalForSessionTicket(b []byte) []byte { b = p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum)) // initial_max_uni_streams b = p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum)) + // max_datagram_frame_size + if p.MaxDatagramFrameSize != protocol.InvalidByteCount { + b = p.marshalVarintParam(b, maxDatagramFrameSizeParameterID, uint64(p.MaxDatagramFrameSize)) + } // active_connection_id_limit return p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit) } @@ -482,6 +486,9 @@ func (p *TransportParameters) UnmarshalFromSessionTicket(r *bytes.Reader) error // ValidFor0RTT checks if the transport parameters match those saved in the session ticket. func (p *TransportParameters) ValidFor0RTT(saved *TransportParameters) bool { + if saved.MaxDatagramFrameSize != protocol.InvalidByteCount && (p.MaxDatagramFrameSize == protocol.InvalidByteCount || p.MaxDatagramFrameSize < saved.MaxDatagramFrameSize) { + return false + } return p.InitialMaxStreamDataBidiLocal >= saved.InitialMaxStreamDataBidiLocal && p.InitialMaxStreamDataBidiRemote >= saved.InitialMaxStreamDataBidiRemote && p.InitialMaxStreamDataUni >= saved.InitialMaxStreamDataUni && @@ -494,6 +501,9 @@ func (p *TransportParameters) ValidFor0RTT(saved *TransportParameters) bool { // ValidForUpdate checks that the new transport parameters don't reduce limits after resuming a 0-RTT connection. // It is only used on the client side. func (p *TransportParameters) ValidForUpdate(saved *TransportParameters) bool { + if saved.MaxDatagramFrameSize != protocol.InvalidByteCount && (p.MaxDatagramFrameSize == protocol.InvalidByteCount || p.MaxDatagramFrameSize < saved.MaxDatagramFrameSize) { + return false + } return p.ActiveConnectionIDLimit >= saved.ActiveConnectionIDLimit && p.InitialMaxData >= saved.InitialMaxData && p.InitialMaxStreamDataBidiLocal >= saved.InitialMaxStreamDataBidiLocal && diff --git a/internal/wire/version_negotiation.go b/internal/wire/version_negotiation.go index cbaa2cb5..3659d103 100644 --- a/internal/wire/version_negotiation.go +++ b/internal/wire/version_negotiation.go @@ -40,7 +40,10 @@ func ComposeVersionNegotiation(destConnID, srcConnID protocol.ArbitraryLenConnec buf := bytes.NewBuffer(make([]byte, 0, expectedLen)) r := make([]byte, 1) _, _ = rand.Read(r) // ignore the error here. It is not critical to have perfect random here. - buf.WriteByte(r[0] | 0x80) + // Setting the "QUIC bit" (0x40) is not required by the RFC, + // but it allows clients to demultiplex QUIC with a long list of other protocols. + // See RFC 9443 and https://mailarchive.ietf.org/arch/msg/quic/oR4kxGKY6mjtPC1CZegY1ED4beg/ for details. + buf.WriteByte(r[0] | 0xc0) utils.BigEndian.WriteUint32(buf, 0) // version 0 buf.WriteByte(uint8(destConnID.Len())) buf.Write(destConnID.Bytes()) diff --git a/internal/wire/version_negotiation_test.go b/internal/wire/version_negotiation_test.go index 529a3234..e47d2c83 100644 --- a/internal/wire/version_negotiation_test.go +++ b/internal/wire/version_negotiation_test.go @@ -64,6 +64,7 @@ var _ = Describe("Version Negotiation Packets", func() { versions := []protocol.VersionNumber{1001, 1003} data := ComposeVersionNegotiation(destConnID, srcConnID, versions) Expect(IsLongHeaderPacket(data[0])).To(BeTrue()) + Expect(data[0] & 0x40).ToNot(BeZero()) v, err := ParseVersion(data) Expect(err).ToNot(HaveOccurred()) Expect(v).To(BeZero()) diff --git a/logging/logging_suite_test.go b/logging/logging_suite_test.go index d37ada48..d808adfe 100644 --- a/logging/logging_suite_test.go +++ b/logging/logging_suite_test.go @@ -3,10 +3,9 @@ package logging import ( "testing" - "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) func TestLogging(t *testing.T) { diff --git a/logging/mock_connection_tracer_test.go b/logging/mock_connection_tracer_test.go index 5fad29ce..e5dd38c2 100644 --- a/logging/mock_connection_tracer_test.go +++ b/logging/mock_connection_tracer_test.go @@ -9,7 +9,7 @@ import ( reflect "reflect" time "time" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" utils "github.com/refraction-networking/uquic/internal/utils" wire "github.com/refraction-networking/uquic/internal/wire" diff --git a/logging/mock_tracer_test.go b/logging/mock_tracer_test.go index cb13831d..d295c96d 100644 --- a/logging/mock_tracer_test.go +++ b/logging/mock_tracer_test.go @@ -8,7 +8,7 @@ import ( net "net" reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" wire "github.com/refraction-networking/uquic/internal/wire" ) diff --git a/logging/mockgen.go b/logging/mockgen.go index 8b7fab64..66f712bc 100644 --- a/logging/mockgen.go +++ b/logging/mockgen.go @@ -1,4 +1,4 @@ package logging -//go:generate sh -c "go run github.com/golang/mock/mockgen -package logging -self_package github.com/refraction-networking/uquic/logging -destination mock_connection_tracer_test.go github.com/refraction-networking/uquic/logging ConnectionTracer" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package logging -self_package github.com/refraction-networking/uquic/logging -destination mock_tracer_test.go github.com/refraction-networking/uquic/logging Tracer" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package logging -self_package github.com/refraction-networking/uquic/logging -destination mock_connection_tracer_test.go github.com/refraction-networking/uquic/logging ConnectionTracer" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package logging -self_package github.com/refraction-networking/uquic/logging -destination mock_tracer_test.go github.com/refraction-networking/uquic/logging Tracer" diff --git a/mock_ack_frame_source_test.go b/mock_ack_frame_source_test.go index ab70a522..4ebc2d42 100644 --- a/mock_ack_frame_source_test.go +++ b/mock_ack_frame_source_test.go @@ -7,7 +7,7 @@ package quic import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" wire "github.com/refraction-networking/uquic/internal/wire" ) diff --git a/mock_batch_conn_test.go b/mock_batch_conn_test.go index fcb23e34..9621e7b4 100644 --- a/mock_batch_conn_test.go +++ b/mock_batch_conn_test.go @@ -7,7 +7,7 @@ package quic import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" ipv4 "golang.org/x/net/ipv4" ) diff --git a/mock_conn_runner_test.go b/mock_conn_runner_test.go index ea9e8f56..d857fd58 100644 --- a/mock_conn_runner_test.go +++ b/mock_conn_runner_test.go @@ -7,7 +7,7 @@ package quic import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/mock_crypto_data_handler_test.go b/mock_crypto_data_handler_test.go index 9bc0e86f..eb2e90fa 100644 --- a/mock_crypto_data_handler_test.go +++ b/mock_crypto_data_handler_test.go @@ -7,7 +7,7 @@ package quic import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" handshake "github.com/refraction-networking/uquic/internal/handshake" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/mock_crypto_stream_test.go b/mock_crypto_stream_test.go index 737bc08e..002b1d17 100644 --- a/mock_crypto_stream_test.go +++ b/mock_crypto_stream_test.go @@ -7,7 +7,7 @@ package quic import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" wire "github.com/refraction-networking/uquic/internal/wire" ) diff --git a/mock_frame_source_test.go b/mock_frame_source_test.go index 622cb78b..7319739c 100644 --- a/mock_frame_source_test.go +++ b/mock_frame_source_test.go @@ -7,7 +7,7 @@ package quic import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" ackhandler "github.com/refraction-networking/uquic/internal/ackhandler" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/mock_mtu_discoverer_test.go b/mock_mtu_discoverer_test.go index 3ebe8f4e..a1c15fb1 100644 --- a/mock_mtu_discoverer_test.go +++ b/mock_mtu_discoverer_test.go @@ -8,7 +8,7 @@ import ( reflect "reflect" time "time" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" ackhandler "github.com/refraction-networking/uquic/internal/ackhandler" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/mock_packer_test.go b/mock_packer_test.go index fe966fb9..409a7a8c 100644 --- a/mock_packer_test.go +++ b/mock_packer_test.go @@ -7,7 +7,7 @@ package quic import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" ackhandler "github.com/refraction-networking/uquic/internal/ackhandler" protocol "github.com/refraction-networking/uquic/internal/protocol" qerr "github.com/refraction-networking/uquic/internal/qerr" diff --git a/mock_packet_handler_manager_test.go b/mock_packet_handler_manager_test.go index 7b5c9a96..b9e28eeb 100644 --- a/mock_packet_handler_manager_test.go +++ b/mock_packet_handler_manager_test.go @@ -7,7 +7,7 @@ package quic import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/mock_packet_handler_test.go b/mock_packet_handler_test.go index 823d2358..8402d5e4 100644 --- a/mock_packet_handler_test.go +++ b/mock_packet_handler_test.go @@ -7,7 +7,7 @@ package quic import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/mock_packetconn_test.go b/mock_packetconn_test.go index d6731e4a..c8e20bf2 100644 --- a/mock_packetconn_test.go +++ b/mock_packetconn_test.go @@ -9,7 +9,7 @@ import ( reflect "reflect" time "time" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" ) // MockPacketConn is a mock of PacketConn interface. diff --git a/mock_quic_conn_test.go b/mock_quic_conn_test.go index 30a29199..7ce24333 100644 --- a/mock_quic_conn_test.go +++ b/mock_quic_conn_test.go @@ -9,7 +9,7 @@ import ( net "net" reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" qerr "github.com/refraction-networking/uquic/internal/qerr" ) diff --git a/mock_raw_conn_test.go b/mock_raw_conn_test.go new file mode 100644 index 00000000..0a1a0f3a --- /dev/null +++ b/mock_raw_conn_test.go @@ -0,0 +1,122 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/quic-go/quic-go (interfaces: RawConn) + +// Package quic is a generated GoMock package. +package quic + +import ( + net "net" + reflect "reflect" + time "time" + + gomock "go.uber.org/mock/gomock" +) + +// MockRawConn is a mock of RawConn interface. +type MockRawConn struct { + ctrl *gomock.Controller + recorder *MockRawConnMockRecorder +} + +// MockRawConnMockRecorder is the mock recorder for MockRawConn. +type MockRawConnMockRecorder struct { + mock *MockRawConn +} + +// NewMockRawConn creates a new mock instance. +func NewMockRawConn(ctrl *gomock.Controller) *MockRawConn { + mock := &MockRawConn{ctrl: ctrl} + mock.recorder = &MockRawConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRawConn) EXPECT() *MockRawConnMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockRawConn) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockRawConnMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRawConn)(nil).Close)) +} + +// LocalAddr mocks base method. +func (m *MockRawConn) LocalAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LocalAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// LocalAddr indicates an expected call of LocalAddr. +func (mr *MockRawConnMockRecorder) LocalAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockRawConn)(nil).LocalAddr)) +} + +// ReadPacket mocks base method. +func (m *MockRawConn) ReadPacket() (receivedPacket, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadPacket") + ret0, _ := ret[0].(receivedPacket) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadPacket indicates an expected call of ReadPacket. +func (mr *MockRawConnMockRecorder) ReadPacket() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadPacket", reflect.TypeOf((*MockRawConn)(nil).ReadPacket)) +} + +// SetReadDeadline mocks base method. +func (m *MockRawConn) SetReadDeadline(arg0 time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetReadDeadline", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetReadDeadline indicates an expected call of SetReadDeadline. +func (mr *MockRawConnMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockRawConn)(nil).SetReadDeadline), arg0) +} + +// WritePacket mocks base method. +func (m *MockRawConn) WritePacket(arg0 []byte, arg1 net.Addr, arg2 []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WritePacket", arg0, arg1, arg2) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// WritePacket indicates an expected call of WritePacket. +func (mr *MockRawConnMockRecorder) WritePacket(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WritePacket", reflect.TypeOf((*MockRawConn)(nil).WritePacket), arg0, arg1, arg2) +} + +// capabilities mocks base method. +func (m *MockRawConn) capabilities() connCapabilities { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "capabilities") + ret0, _ := ret[0].(connCapabilities) + return ret0 +} + +// capabilities indicates an expected call of capabilities. +func (mr *MockRawConnMockRecorder) capabilities() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "capabilities", reflect.TypeOf((*MockRawConn)(nil).capabilities)) +} diff --git a/mock_receive_stream_internal_test.go b/mock_receive_stream_internal_test.go index 72f9757e..6237feef 100644 --- a/mock_receive_stream_internal_test.go +++ b/mock_receive_stream_internal_test.go @@ -8,7 +8,7 @@ import ( reflect "reflect" time "time" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" qerr "github.com/refraction-networking/uquic/internal/qerr" wire "github.com/refraction-networking/uquic/internal/wire" diff --git a/mock_sealing_manager_test.go b/mock_sealing_manager_test.go index a8750905..1b9c9214 100644 --- a/mock_sealing_manager_test.go +++ b/mock_sealing_manager_test.go @@ -7,7 +7,7 @@ package quic import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" handshake "github.com/refraction-networking/uquic/internal/handshake" ) diff --git a/mock_send_conn_test.go b/mock_send_conn_test.go index e63df48f..f55feaeb 100644 --- a/mock_send_conn_test.go +++ b/mock_send_conn_test.go @@ -8,7 +8,7 @@ import ( net "net" reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/mock_send_stream_internal_test.go b/mock_send_stream_internal_test.go index 2332e5ef..46bf7565 100644 --- a/mock_send_stream_internal_test.go +++ b/mock_send_stream_internal_test.go @@ -9,7 +9,7 @@ import ( reflect "reflect" time "time" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" ackhandler "github.com/refraction-networking/uquic/internal/ackhandler" protocol "github.com/refraction-networking/uquic/internal/protocol" qerr "github.com/refraction-networking/uquic/internal/qerr" diff --git a/mock_sender_test.go b/mock_sender_test.go index bb3fa128..4c989937 100644 --- a/mock_sender_test.go +++ b/mock_sender_test.go @@ -7,7 +7,7 @@ package quic import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/mock_stream_getter_test.go b/mock_stream_getter_test.go index 6f859057..3ed36515 100644 --- a/mock_stream_getter_test.go +++ b/mock_stream_getter_test.go @@ -7,7 +7,7 @@ package quic import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" ) diff --git a/mock_stream_internal_test.go b/mock_stream_internal_test.go index 7acb402c..369f43be 100644 --- a/mock_stream_internal_test.go +++ b/mock_stream_internal_test.go @@ -9,7 +9,7 @@ import ( reflect "reflect" time "time" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" ackhandler "github.com/refraction-networking/uquic/internal/ackhandler" protocol "github.com/refraction-networking/uquic/internal/protocol" qerr "github.com/refraction-networking/uquic/internal/qerr" diff --git a/mock_stream_manager_test.go b/mock_stream_manager_test.go index 53294be4..20e69a83 100644 --- a/mock_stream_manager_test.go +++ b/mock_stream_manager_test.go @@ -8,7 +8,7 @@ import ( context "context" reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" wire "github.com/refraction-networking/uquic/internal/wire" ) diff --git a/mock_stream_sender_test.go b/mock_stream_sender_test.go index c23ad3a9..2c060856 100644 --- a/mock_stream_sender_test.go +++ b/mock_stream_sender_test.go @@ -7,7 +7,7 @@ package quic import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" wire "github.com/refraction-networking/uquic/internal/wire" ) diff --git a/mock_token_store_test.go b/mock_token_store_test.go index 72d85962..c5d28643 100644 --- a/mock_token_store_test.go +++ b/mock_token_store_test.go @@ -7,7 +7,7 @@ package quic import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" ) // MockTokenStore is a mock of TokenStore interface. diff --git a/mock_unknown_packet_handler_test.go b/mock_unknown_packet_handler_test.go index 09c11569..fa702529 100644 --- a/mock_unknown_packet_handler_test.go +++ b/mock_unknown_packet_handler_test.go @@ -7,7 +7,7 @@ package quic import ( reflect "reflect" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" ) // MockUnknownPacketHandler is a mock of UnknownPacketHandler interface. diff --git a/mock_unpacker_test.go b/mock_unpacker_test.go index a0d3dc60..89dbf9c8 100644 --- a/mock_unpacker_test.go +++ b/mock_unpacker_test.go @@ -8,7 +8,7 @@ import ( reflect "reflect" time "time" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" protocol "github.com/refraction-networking/uquic/internal/protocol" wire "github.com/refraction-networking/uquic/internal/wire" ) diff --git a/mockgen.go b/mockgen.go index de20437a..5adfbf59 100644 --- a/mockgen.go +++ b/mockgen.go @@ -2,73 +2,76 @@ package quic -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_send_conn_test.go github.com/refraction-networking/uquic SendConn" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_send_conn_test.go github.com/refraction-networking/uquic SendConn" type SendConn = sendConn -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_sender_test.go github.com/refraction-networking/uquic Sender" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_raw_conn_test.go github.com/refraction-networking/uquic RawConn" +type RawConn = rawConn + +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_sender_test.go github.com/refraction-networking/uquic Sender" type Sender = sender -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_stream_internal_test.go github.com/refraction-networking/uquic StreamI" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_stream_internal_test.go github.com/refraction-networking/uquic StreamI" type StreamI = streamI -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_crypto_stream_test.go github.com/refraction-networking/uquic CryptoStream" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_crypto_stream_test.go github.com/refraction-networking/uquic CryptoStream" type CryptoStream = cryptoStream -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_receive_stream_internal_test.go github.com/refraction-networking/uquic ReceiveStreamI" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_receive_stream_internal_test.go github.com/refraction-networking/uquic ReceiveStreamI" type ReceiveStreamI = receiveStreamI -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_send_stream_internal_test.go github.com/refraction-networking/uquic SendStreamI" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_send_stream_internal_test.go github.com/refraction-networking/uquic SendStreamI" type SendStreamI = sendStreamI -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_stream_getter_test.go github.com/refraction-networking/uquic StreamGetter" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_stream_getter_test.go github.com/refraction-networking/uquic StreamGetter" type StreamGetter = streamGetter -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_stream_sender_test.go github.com/refraction-networking/uquic StreamSender" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_stream_sender_test.go github.com/refraction-networking/uquic StreamSender" type StreamSender = streamSender -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_crypto_data_handler_test.go github.com/refraction-networking/uquic CryptoDataHandler" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_crypto_data_handler_test.go github.com/refraction-networking/uquic CryptoDataHandler" type CryptoDataHandler = cryptoDataHandler -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_frame_source_test.go github.com/refraction-networking/uquic FrameSource" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_frame_source_test.go github.com/refraction-networking/uquic FrameSource" type FrameSource = frameSource -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_ack_frame_source_test.go github.com/refraction-networking/uquic AckFrameSource" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_ack_frame_source_test.go github.com/refraction-networking/uquic AckFrameSource" type AckFrameSource = ackFrameSource -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_stream_manager_test.go github.com/refraction-networking/uquic StreamManager" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_stream_manager_test.go github.com/refraction-networking/uquic StreamManager" type StreamManager = streamManager -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_sealing_manager_test.go github.com/refraction-networking/uquic SealingManager" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_sealing_manager_test.go github.com/refraction-networking/uquic SealingManager" type SealingManager = sealingManager -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_unpacker_test.go github.com/refraction-networking/uquic Unpacker" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_unpacker_test.go github.com/refraction-networking/uquic Unpacker" type Unpacker = unpacker -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_packer_test.go github.com/refraction-networking/uquic Packer" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_packer_test.go github.com/refraction-networking/uquic Packer" type Packer = packer -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_mtu_discoverer_test.go github.com/refraction-networking/uquic MTUDiscoverer" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_mtu_discoverer_test.go github.com/refraction-networking/uquic MTUDiscoverer" type MTUDiscoverer = mtuDiscoverer -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_conn_runner_test.go github.com/refraction-networking/uquic ConnRunner" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_conn_runner_test.go github.com/refraction-networking/uquic ConnRunner" type ConnRunner = connRunner -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_quic_conn_test.go github.com/refraction-networking/uquic QUICConn" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_quic_conn_test.go github.com/refraction-networking/uquic QUICConn" type QUICConn = quicConn -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_packet_handler_test.go github.com/refraction-networking/uquic PacketHandler" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_packet_handler_test.go github.com/refraction-networking/uquic PacketHandler" type PacketHandler = packetHandler -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_unknown_packet_handler_test.go github.com/refraction-networking/uquic UnknownPacketHandler" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_unknown_packet_handler_test.go github.com/refraction-networking/uquic UnknownPacketHandler" type UnknownPacketHandler = unknownPacketHandler -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_packet_handler_manager_test.go github.com/refraction-networking/uquic PacketHandlerManager" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_packet_handler_manager_test.go github.com/refraction-networking/uquic PacketHandlerManager" type PacketHandlerManager = packetHandlerManager // Need to use source mode for the batchConn, since reflect mode follows type aliases. // See https://github.com/golang/mock/issues/244 for details. // -//go:generate sh -c "go run github.com/golang/mock/mockgen -package quic -self_package github.com/refraction-networking/uquic -source sys_conn_oob.go -destination mock_batch_conn_test.go -mock_names batchConn=MockBatchConn" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package quic -self_package github.com/refraction-networking/uquic -source sys_conn_oob.go -destination mock_batch_conn_test.go -mock_names batchConn=MockBatchConn" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package quic -self_package github.com/refraction-networking/uquic -self_package github.com/refraction-networking/uquic -destination mock_token_store_test.go github.com/refraction-networking/uquic TokenStore" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package quic -self_package github.com/refraction-networking/uquic -self_package github.com/refraction-networking/uquic -destination mock_packetconn_test.go net PacketConn" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package quic -self_package github.com/refraction-networking/uquic -self_package github.com/refraction-networking/uquic -destination mock_token_store_test.go github.com/refraction-networking/uquic TokenStore" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package quic -self_package github.com/refraction-networking/uquic -self_package github.com/refraction-networking/uquic -destination mock_packetconn_test.go net PacketConn" diff --git a/packet_handler_map.go b/packet_handler_map.go index a99a8562..41287c2a 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -26,9 +26,9 @@ type connCapabilities struct { // rawConn is a connection that allow reading of a receivedPackeh. type rawConn interface { ReadPacket() (receivedPacket, error) - // The size parameter is used for GSO. - // If GSO is not support, len(b) must be equal to size. - WritePacket(b []byte, size uint16, addr net.Addr, oob []byte) (int, error) + // WritePacket writes a packet on the wire. + // If GSO is enabled, it's the caller's responsibility to set the correct control message. + WritePacket(b []byte, addr net.Addr, oob []byte) (int, error) LocalAddr() net.Addr SetReadDeadline(time.Time) error io.Closer diff --git a/packet_packer_test.go b/packet_packer_test.go index abf59f32..74e05a67 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -2,6 +2,7 @@ package quic import ( "bytes" + "errors" "fmt" "net" "time" @@ -17,10 +18,9 @@ import ( "github.com/refraction-networking/uquic/internal/utils" "github.com/refraction-networking/uquic/internal/wire" - "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) var _ = Describe("Packet packer", func() { @@ -334,7 +334,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped) sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - quicErr := qerr.NewLocalCryptoError(0x42, "crypto error") + quicErr := qerr.NewLocalCryptoError(0x42, errors.New("crypto error")) quicErr.FrameType = 0x1234 p, err := packer.PackConnectionClose(quicErr, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index 8ca1c8b4..9584df4c 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -10,9 +10,9 @@ import ( "github.com/refraction-networking/uquic/internal/qerr" "github.com/refraction-networking/uquic/internal/wire" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) var _ = Describe("Packet Unpacker", func() { diff --git a/quic_suite_test.go b/quic_suite_test.go index d979d81b..954ca60b 100644 --- a/quic_suite_test.go +++ b/quic_suite_test.go @@ -9,9 +9,9 @@ import ( "sync" "testing" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) func TestQuicGo(t *testing.T) { diff --git a/receive_stream_test.go b/receive_stream_test.go index 9cdbbd72..23396e1f 100644 --- a/receive_stream_test.go +++ b/receive_stream_test.go @@ -12,10 +12,10 @@ import ( "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/wire" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "github.com/onsi/gomega/gbytes" + "go.uber.org/mock/gomock" ) var _ = Describe("Receive Stream", func() { diff --git a/send_conn.go b/send_conn.go index 63636e78..a3feaf62 100644 --- a/send_conn.go +++ b/send_conn.go @@ -1,10 +1,12 @@ package quic import ( + "fmt" "math" "net" "github.com/refraction-networking/uquic/internal/protocol" + "github.com/refraction-networking/uquic/internal/utils" ) // A sendConn allows sending using a simple Write() on a non-connected packet conn. @@ -20,61 +22,86 @@ type sendConn interface { type sconn struct { rawConn + localAddr net.Addr remoteAddr net.Addr - info packetInfo - oob []byte + + logger utils.Logger + + info packetInfo + oob []byte + // If GSO enabled, and we receive a GSO error for this remote address, GSO is disabled. + gotGSOError bool } var _ sendConn = &sconn{} -func newSendConn(c rawConn, remote net.Addr) *sconn { - sc := &sconn{ - rawConn: c, - remoteAddr: remote, +func newSendConn(c rawConn, remote net.Addr, info packetInfo, logger utils.Logger) *sconn { + localAddr := c.LocalAddr() + if info.addr.IsValid() { + if udpAddr, ok := localAddr.(*net.UDPAddr); ok { + addrCopy := *udpAddr + addrCopy.IP = info.addr.AsSlice() + localAddr = &addrCopy + } } - if c.capabilities().GSO { - // add 32 bytes, so we can add the UDP_SEGMENT msg - sc.oob = make([]byte, 0, 32) - } - return sc -} -func newSendConnWithPacketInfo(c rawConn, remote net.Addr, info packetInfo) *sconn { oob := info.OOB() - if c.capabilities().GSO { - // add 32 bytes, so we can add the UDP_SEGMENT msg - l := len(oob) - oob = append(oob, make([]byte, 32)...) - oob = oob[:l] - } + // add 32 bytes, so we can add the UDP_SEGMENT msg + l := len(oob) + oob = append(oob, make([]byte, 32)...) + oob = oob[:l] return &sconn{ rawConn: c, + localAddr: localAddr, remoteAddr: remote, info: info, oob: oob, + logger: logger, } } func (c *sconn) Write(p []byte, size protocol.ByteCount) error { + if !c.capabilities().GSO { + if protocol.ByteCount(len(p)) != size { + panic(fmt.Sprintf("inconsistent packet size (%d vs %d)", len(p), size)) + } + _, err := c.WritePacket(p, c.remoteAddr, c.oob) + return err + } + // GSO is supported. Append the control message and send. if size > math.MaxUint16 { panic("size overflow") } - _, err := c.WritePacket(p, uint16(size), c.remoteAddr, c.oob) + _, err := c.WritePacket(p, c.remoteAddr, appendUDPSegmentSizeMsg(c.oob, uint16(size))) + if err != nil && isGSOError(err) { + // disable GSO for future calls + c.gotGSOError = true + if c.logger.Debug() { + c.logger.Debugf("GSO failed when sending to %s", c.remoteAddr) + } + // send out the packets one by one + for len(p) > 0 { + l := len(p) + if l > int(size) { + l = int(size) + } + if _, err := c.WritePacket(p[:l], c.remoteAddr, c.oob); err != nil { + return err + } + p = p[l:] + } + return nil + } return err } -func (c *sconn) RemoteAddr() net.Addr { - return c.remoteAddr +func (c *sconn) capabilities() connCapabilities { + capabilities := c.rawConn.capabilities() + if capabilities.GSO { + capabilities.GSO = !c.gotGSOError + } + return capabilities } -func (c *sconn) LocalAddr() net.Addr { - addr := c.rawConn.LocalAddr() - if c.info.addr.IsValid() { - if udpAddr, ok := addr.(*net.UDPAddr); ok { - addrCopy := *udpAddr - addrCopy.IP = c.info.addr.AsSlice() - addr = &addrCopy - } - } - return addr -} +func (c *sconn) RemoteAddr() net.Addr { return c.remoteAddr } +func (c *sconn) LocalAddr() net.Addr { return c.localAddr } diff --git a/send_conn_test.go b/send_conn_test.go index 56fe9236..7f072430 100644 --- a/send_conn_test.go +++ b/send_conn_test.go @@ -2,46 +2,81 @@ package quic import ( "net" + "net/netip" + + "github.com/quic-go/quic-go/internal/utils" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) +// Only if appendUDPSegmentSizeMsg actually appends a message (and isn't only a stub implementation), +// GSO is actually supported on this platform. +var platformSupportsGSO = len(appendUDPSegmentSizeMsg([]byte{}, 1337)) > 0 + var _ = Describe("Connection (for sending packets)", func() { - var ( - c sendConn - packetConn *MockPacketConn - addr net.Addr - ) + remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} - BeforeEach(func() { - addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} - packetConn = NewMockPacketConn(mockCtrl) - rawConn, err := wrapConn(packetConn) - Expect(err).ToNot(HaveOccurred()) - c = newSendConnWithPacketInfo(rawConn, addr, packetInfo{}) - }) - - It("writes", func() { - packetConn.EXPECT().WriteTo([]byte("foobar"), addr) - Expect(c.Write([]byte("foobar"), 6)).To(Succeed()) - }) - - It("gets the remote address", func() { + It("gets the local and remote addresses", func() { + localAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1234} + rawConn := NewMockRawConn(mockCtrl) + rawConn.EXPECT().LocalAddr().Return(localAddr) + c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger) + Expect(c.LocalAddr().String()).To(Equal("192.168.0.1:1234")) Expect(c.RemoteAddr().String()).To(Equal("192.168.100.200:1337")) }) - It("gets the local address", func() { - addr := &net.UDPAddr{ - IP: net.IPv4(192, 168, 0, 1), - Port: 1234, - } - packetConn.EXPECT().LocalAddr().Return(addr) - Expect(c.LocalAddr()).To(Equal(addr)) + It("uses the local address from the packet info", func() { + localAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1234} + rawConn := NewMockRawConn(mockCtrl) + rawConn.EXPECT().LocalAddr().Return(localAddr) + c := newSendConn(rawConn, remoteAddr, packetInfo{addr: netip.AddrFrom4([4]byte{127, 0, 0, 42})}, utils.DefaultLogger) + Expect(c.LocalAddr().String()).To(Equal("127.0.0.42:1234")) }) - It("closes", func() { - packetConn.EXPECT().Close() - Expect(c.Close()).To(Succeed()) - }) + if platformSupportsGSO { + It("writes with GSO", func() { + rawConn := NewMockRawConn(mockCtrl) + rawConn.EXPECT().LocalAddr() + rawConn.EXPECT().capabilities().Return(connCapabilities{GSO: true}).AnyTimes() + c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger) + rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any()).Do(func(_ []byte, _ net.Addr, oob []byte) { + msg := appendUDPSegmentSizeMsg([]byte{}, 3) + Expect(oob).To(Equal(msg)) + }) + Expect(c.Write([]byte("foobar"), 3)).To(Succeed()) + }) + + It("disables GSO if writing fails", func() { + rawConn := NewMockRawConn(mockCtrl) + rawConn.EXPECT().LocalAddr() + rawConn.EXPECT().capabilities().Return(connCapabilities{GSO: true}).AnyTimes() + c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger) + Expect(c.capabilities().GSO).To(BeTrue()) + gomock.InOrder( + rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any()).DoAndReturn(func(_ []byte, _ net.Addr, oob []byte) (int, error) { + msg := appendUDPSegmentSizeMsg([]byte{}, 3) + Expect(oob).To(Equal(msg)) + return 0, errGSO + }), + rawConn.EXPECT().WritePacket([]byte("foo"), remoteAddr, gomock.Len(0)).Return(3, nil), + rawConn.EXPECT().WritePacket([]byte("bar"), remoteAddr, gomock.Len(0)).Return(3, nil), + ) + Expect(c.Write([]byte("foobar"), 3)).To(Succeed()) + Expect(c.capabilities().GSO).To(BeFalse()) // GSO support is now disabled + // make sure we actually enforce that + Expect(func() { c.Write([]byte("foobar"), 3) }).To(PanicWith("inconsistent packet size (6 vs 3)")) + }) + } else { + It("writes without GSO", func() { + remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} + rawConn := NewMockRawConn(mockCtrl) + rawConn.EXPECT().LocalAddr() + rawConn.EXPECT().capabilities() + c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger) + rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Len(0)) + Expect(c.Write([]byte("foobar"), 6)).To(Succeed()) + }) + } }) diff --git a/send_queue_test.go b/send_queue_test.go index 1bc260b1..6abe5612 100644 --- a/send_queue_test.go +++ b/send_queue_test.go @@ -5,9 +5,9 @@ import ( "github.com/refraction-networking/uquic/internal/protocol" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) var _ = Describe("Send Queue", func() { diff --git a/send_stream_test.go b/send_stream_test.go index 0a9f7499..ab5589f9 100644 --- a/send_stream_test.go +++ b/send_stream_test.go @@ -11,7 +11,6 @@ import ( "golang.org/x/exp/rand" - "github.com/golang/mock/gomock" "github.com/refraction-networking/uquic/internal/ackhandler" "github.com/refraction-networking/uquic/internal/mocks" "github.com/refraction-networking/uquic/internal/protocol" @@ -20,6 +19,7 @@ import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "github.com/onsi/gomega/gbytes" + "go.uber.org/mock/gomock" ) var _ = Describe("Send Stream", func() { diff --git a/server.go b/server.go index 1815be3f..5b80ec29 100644 --- a/server.go +++ b/server.go @@ -633,7 +633,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID) } conn = s.newConn( - newSendConnWithPacketInfo(s.conn, p.remoteAddr, p.info), + newSendConn(s.conn, p.remoteAddr, p.info, s.logger), s.connHandler, origDestConnID, retrySrcConnID, @@ -743,7 +743,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info packe if s.tracer != nil { s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil) } - _, err = s.conn.WritePacket(buf.Data, uint16(len(buf.Data)), remoteAddr, info.OOB()) + _, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB()) return err } @@ -842,7 +842,7 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han if s.tracer != nil { s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf}) } - _, err = s.conn.WritePacket(b.Data, uint16(len(b.Data)), remoteAddr, info.OOB()) + _, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB()) return err } @@ -880,7 +880,7 @@ func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) { if s.tracer != nil { s.tracer.SentVersionNegotiationPacket(p.remoteAddr, src, dest, s.config.Versions) } - if _, err := s.conn.WritePacket(data, uint16(len(data)), p.remoteAddr, p.info.OOB()); err != nil { + if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { s.logger.Debugf("Error sending Version Negotiation: %s", err) } } diff --git a/server_test.go b/server_test.go index c3064016..d1563ad3 100644 --- a/server_test.go +++ b/server_test.go @@ -21,9 +21,9 @@ import ( "github.com/refraction-networking/uquic/internal/wire" "github.com/refraction-networking/uquic/logging" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) var _ = Describe("Server", func() { diff --git a/streams_map_incoming_test.go b/streams_map_incoming_test.go index 3928997b..897e0883 100644 --- a/streams_map_incoming_test.go +++ b/streams_map_incoming_test.go @@ -10,9 +10,9 @@ import ( "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/wire" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) type mockGenericStream struct { diff --git a/streams_map_outgoing_test.go b/streams_map_outgoing_test.go index 4ab03581..43d9d3db 100644 --- a/streams_map_outgoing_test.go +++ b/streams_map_outgoing_test.go @@ -13,9 +13,9 @@ import ( "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/wire" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) var _ = Describe("Streams Map (outgoing)", func() { diff --git a/streams_map_test.go b/streams_map_test.go index 1afc1db9..0fceffd6 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -6,8 +6,6 @@ import ( "fmt" "net" - "github.com/golang/mock/gomock" - "github.com/refraction-networking/uquic/internal/flowcontrol" "github.com/refraction-networking/uquic/internal/mocks" "github.com/refraction-networking/uquic/internal/protocol" @@ -16,6 +14,7 @@ import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) func (e streamError) TestError() error { diff --git a/sys_conn.go b/sys_conn.go index d1ba2bbd..88e098fd 100644 --- a/sys_conn.go +++ b/sys_conn.go @@ -1,7 +1,6 @@ package quic import ( - "fmt" "log" "net" "os" @@ -105,10 +104,7 @@ func (c *basicConn) ReadPacket() (receivedPacket, error) { }, nil } -func (c *basicConn) WritePacket(b []byte, packetSize uint16, addr net.Addr, _ []byte) (n int, err error) { - if uint16(len(b)) != packetSize { - panic(fmt.Sprintf("inconsistent length. got: %d. expected %d", packetSize, len(b))) - } +func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte) (n int, err error) { return c.PacketConn.WriteTo(b, addr) } diff --git a/sys_conn_df_linux.go b/sys_conn_df_linux.go index 4a988b33..5c8da7d0 100644 --- a/sys_conn_df_linux.go +++ b/sys_conn_df_linux.go @@ -4,11 +4,7 @@ package quic import ( "errors" - "log" - "os" - "strconv" "syscall" - "unsafe" "golang.org/x/sys/unix" @@ -38,43 +34,9 @@ func setDF(rawConn syscall.RawConn) (bool, error) { return true, nil } -func maybeSetGSO(rawConn syscall.RawConn) bool { - enable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_ENABLE_GSO")) - if !enable { - return false - } - - var setErr error - if err := rawConn.Control(func(fd uintptr) { - setErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_UDP, unix.UDP_SEGMENT, 1) - }); err != nil { - setErr = err - } - if setErr != nil { - log.Println("failed to enable GSO") - return false - } - return true -} - func isSendMsgSizeErr(err error) bool { // https://man7.org/linux/man-pages/man7/udp.7.html return errors.Is(err, unix.EMSGSIZE) } -func isRecvMsgSizeErr(err error) bool { return false } - -func appendUDPSegmentSizeMsg(b []byte, size uint16) []byte { - startLen := len(b) - const dataLen = 2 // payload is a uint16 - b = append(b, make([]byte, unix.CmsgSpace(dataLen))...) - h := (*unix.Cmsghdr)(unsafe.Pointer(&b[startLen])) - h.Level = syscall.IPPROTO_UDP - h.Type = unix.UDP_SEGMENT - h.SetLen(unix.CmsgLen(dataLen)) - - // UnixRights uses the private `data` method, but I *think* this achieves the same goal. - offset := startLen + unix.CmsgSpace(0) - *(*uint16)(unsafe.Pointer(&b[offset])) = size - return b -} +func isRecvMsgSizeErr(error) bool { return false } diff --git a/sys_conn_helper_darwin.go b/sys_conn_helper_darwin.go index bf735f0f..758cf778 100644 --- a/sys_conn_helper_darwin.go +++ b/sys_conn_helper_darwin.go @@ -5,6 +5,7 @@ package quic import ( "encoding/binary" "net/netip" + "syscall" "golang.org/x/sys/unix" ) @@ -29,3 +30,5 @@ func parseIPv4PktInfo(body []byte) (ip netip.Addr, ifIndex uint32, ok bool) { } return netip.AddrFrom4(*(*[4]byte)(body[8:12])), binary.LittleEndian.Uint32(body), true } + +func isGSOSupported(syscall.RawConn) bool { return false } diff --git a/sys_conn_helper_freebsd.go b/sys_conn_helper_freebsd.go index fe5a7c20..a2baae3b 100644 --- a/sys_conn_helper_freebsd.go +++ b/sys_conn_helper_freebsd.go @@ -4,6 +4,7 @@ package quic import ( "net/netip" + "syscall" "golang.org/x/sys/unix" ) @@ -24,3 +25,5 @@ func parseIPv4PktInfo(body []byte) (ip netip.Addr, _ uint32, ok bool) { } return netip.AddrFrom4(*(*[4]byte)(body)), 0, true } + +func isGSOSupported(syscall.RawConn) bool { return false } diff --git a/sys_conn_helper_linux.go b/sys_conn_helper_linux.go index 61224eaa..6a049241 100644 --- a/sys_conn_helper_linux.go +++ b/sys_conn_helper_linux.go @@ -4,8 +4,12 @@ package quic import ( "encoding/binary" + "errors" "net/netip" + "os" + "strconv" "syscall" + "unsafe" "golang.org/x/sys/unix" ) @@ -48,3 +52,46 @@ func parseIPv4PktInfo(body []byte) (ip netip.Addr, ifIndex uint32, ok bool) { } return netip.AddrFrom4(*(*[4]byte)(body[8:12])), binary.LittleEndian.Uint32(body), true } + +// isGSOSupported tests if the kernel supports GSO. +// Sending with GSO might still fail later on, if the interface doesn't support it (see isGSOError). +func isGSOSupported(conn syscall.RawConn) bool { + disabled, err := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_GSO")) + if err == nil && disabled { + return false + } + var serr error + if err := conn.Control(func(fd uintptr) { + _, serr = unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT) + }); err != nil { + return false + } + return serr == nil +} + +func appendUDPSegmentSizeMsg(b []byte, size uint16) []byte { + startLen := len(b) + const dataLen = 2 // payload is a uint16 + b = append(b, make([]byte, unix.CmsgSpace(dataLen))...) + h := (*unix.Cmsghdr)(unsafe.Pointer(&b[startLen])) + h.Level = syscall.IPPROTO_UDP + h.Type = unix.UDP_SEGMENT + h.SetLen(unix.CmsgLen(dataLen)) + + // UnixRights uses the private `data` method, but I *think* this achieves the same goal. + offset := startLen + unix.CmsgSpace(0) + *(*uint16)(unsafe.Pointer(&b[offset])) = size + return b +} + +func isGSOError(err error) bool { + var serr *os.SyscallError + if errors.As(err, &serr) { + // EIO is returned by udp_send_skb() if the device driver does not have tx checksums enabled, + // which is a hard requirement of UDP_SEGMENT. See: + // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228 + // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942 + return serr.Err == unix.EIO + } + return false +} diff --git a/sys_conn_helper_linux_test.go b/sys_conn_helper_linux_test.go index fa39c523..4cf59abe 100644 --- a/sys_conn_helper_linux_test.go +++ b/sys_conn_helper_linux_test.go @@ -1,22 +1,24 @@ -// We need root permissions to use RCVBUFFORCE. -// This test is therefore only compiled when the root build flag is set. -// It can only succeed if the tests are then also run with root permissions. -//go:build linux && root +//go:build linux package quic import ( + "errors" "net" "os" + "golang.org/x/sys/unix" + . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) +var errGSO = &os.SyscallError{Err: unix.EIO} + var _ = Describe("forcing a change of send and receive buffer sizes", func() { It("forces a change of the receive buffer size", func() { if os.Getuid() != 0 { - Fail("Must be root to force change the receive buffer size") + Skip("Must be root to force change the receive buffer size") } c, err := net.ListenPacket("udp", "127.0.0.1:0") @@ -43,7 +45,7 @@ var _ = Describe("forcing a change of send and receive buffer sizes", func() { It("forces a change of the send buffer size", func() { if os.Getuid() != 0 { - Fail("Must be root to force change the send buffer size") + Skip("Must be root to force change the send buffer size") } c, err := net.ListenPacket("udp", "127.0.0.1:0") @@ -67,4 +69,10 @@ var _ = Describe("forcing a change of send and receive buffer sizes", func() { // The kernel doubles this value (to allow space for bookkeeping overhead) Expect(size).To(Equal(2 * large)) }) + + It("detects GSO errors", func() { + Expect(isGSOError(errGSO)).To(BeTrue()) + Expect(isGSOError(nil)).To(BeFalse()) + Expect(isGSOError(errors.New("test"))).To(BeFalse()) + }) }) diff --git a/sys_conn_helper_nonlinux.go b/sys_conn_helper_nonlinux.go index 80b795c3..cace82d5 100644 --- a/sys_conn_helper_nonlinux.go +++ b/sys_conn_helper_nonlinux.go @@ -4,3 +4,6 @@ package quic func forceSetReceiveBuffer(c any, bytes int) error { return nil } func forceSetSendBuffer(c any, bytes int) error { return nil } + +func appendUDPSegmentSizeMsg([]byte, uint16) []byte { return nil } +func isGSOError(error) bool { return false } diff --git a/sys_conn_helper_nonlinux_test.go b/sys_conn_helper_nonlinux_test.go new file mode 100644 index 00000000..29d42ad3 --- /dev/null +++ b/sys_conn_helper_nonlinux_test.go @@ -0,0 +1,7 @@ +//go:build !linux + +package quic + +import "errors" + +var errGSO = errors.New("fake GSO error") diff --git a/sys_conn_no_gso.go b/sys_conn_no_gso.go deleted file mode 100644 index 6f6a8c91..00000000 --- a/sys_conn_no_gso.go +++ /dev/null @@ -1,8 +0,0 @@ -//go:build darwin || freebsd - -package quic - -import "syscall" - -func maybeSetGSO(_ syscall.RawConn) bool { return false } -func appendUDPSegmentSizeMsg(_ []byte, _ uint16) []byte { return nil } diff --git a/sys_conn_oob.go b/sys_conn_oob.go index 140bb2f4..66d5ce67 100644 --- a/sys_conn_oob.go +++ b/sys_conn_oob.go @@ -5,7 +5,6 @@ package quic import ( "encoding/binary" "errors" - "fmt" "log" "net" "net/netip" @@ -128,10 +127,6 @@ func newConn(c OOBCapablePacketConn, supportsDF bool) (*oobConn, error) { bc = ipv4.NewPacketConn(c) } - // Try enabling GSO. - // This will only succeed on Linux, and only for kernels > 4.18. - supportsGSO := maybeSetGSO(rawConn) - msgs := make([]ipv4.Message, batchSize) for i := range msgs { // preallocate the [][]byte @@ -142,9 +137,11 @@ func newConn(c OOBCapablePacketConn, supportsDF bool) (*oobConn, error) { batchConn: bc, messages: msgs, readPos: batchSize, + cap: connCapabilities{ + DF: supportsDF, + GSO: isGSOSupported(rawConn), + }, } - oobConn.cap.DF = supportsDF - oobConn.cap.GSO = supportsGSO for i := 0; i < batchSize; i++ { oobConn.messages[i].OOB = make([]byte, oobBufferSize) } @@ -231,17 +228,9 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) { } // WritePacket writes a new packet. -// If the connection supports GSO (and we activated GSO support before), -// it appends the UDP_SEGMENT size message to oob. -// Callers are advised to make sure that oob has a sufficient capacity, -// such that appending the UDP_SEGMENT size message doesn't cause an allocation. -func (c *oobConn) WritePacket(b []byte, packetSize uint16, addr net.Addr, oob []byte) (n int, err error) { - if c.cap.GSO { - oob = appendUDPSegmentSizeMsg(oob, packetSize) - } else if uint16(len(b)) != packetSize { - panic(fmt.Sprintf("inconsistent length. got: %d. expected %d", packetSize, len(b))) - } - n, _, err = c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr)) +// If the connection supports GSO, it's the caller's responsibility to append the right control mesage. +func (c *oobConn) WritePacket(b []byte, addr net.Addr, oob []byte) (int, error) { + n, _, err := c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr)) return n, err } diff --git a/sys_conn_oob_test.go b/sys_conn_oob_test.go index dce9d0b5..0a4efb94 100644 --- a/sys_conn_oob_test.go +++ b/sys_conn_oob_test.go @@ -13,9 +13,9 @@ import ( "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/utils" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) var _ = Describe("OOB Conn Test", func() { diff --git a/sys_conn_test.go b/sys_conn_test.go index 41269a0f..0af911da 100644 --- a/sys_conn_test.go +++ b/sys_conn_test.go @@ -6,10 +6,9 @@ import ( "github.com/refraction-networking/uquic/internal/protocol" - "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) var _ = Describe("Basic Conn Test", func() { diff --git a/tools.go b/tools.go index e848317f..d00ce748 100644 --- a/tools.go +++ b/tools.go @@ -3,6 +3,6 @@ package quic import ( - _ "github.com/golang/mock/mockgen" _ "github.com/onsi/ginkgo/v2/ginkgo" + _ "go.uber.org/mock/mockgen" ) diff --git a/transport.go b/transport.go index ff4cf144..a0d0784a 100644 --- a/transport.go +++ b/transport.go @@ -6,14 +6,14 @@ import ( "errors" "net" "sync" + "sync/atomic" "time" tls "github.com/refraction-networking/utls" - "github.com/refraction-networking/uquic/internal/wire" - "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/utils" + "github.com/refraction-networking/uquic/internal/wire" "github.com/refraction-networking/uquic/logging" ) @@ -86,6 +86,9 @@ type Transport struct { createdConn bool isSingleUse bool // was created for a single server or client, i.e. by calling quic.Listen or quic.Dial + readingNonQUICPackets atomic.Bool + nonQUICPackets chan receivedPacket + logger utils.Logger } @@ -149,26 +152,15 @@ func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListen // Dial dials a new connection to a remote host (not using 0-RTT). func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) { - if err := validateConfig(conf); err != nil { - return nil, err - } - conf = populateConfig(conf) - - if err := t.init(t.isSingleUse); err != nil { - return nil, err - } - var onClose func() - if t.isSingleUse { - onClose = func() { t.Close() } - } - tlsConf = tlsConf.Clone() - tlsConf.MinVersion = tls.VersionTLS13 - - return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false) + return t.dial(ctx, addr, "", tlsConf, conf, false) } // DialEarly dials a new connection, attempting to use 0-RTT if possible. func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) { + return t.dial(ctx, addr, "", tlsConf, conf, true) +} + +func (t *Transport) dial(ctx context.Context, addr net.Addr, host string, tlsConf *tls.Config, conf *Config, use0RTT bool) (EarlyConnection, error) { if err := validateConfig(conf); err != nil { return nil, err } @@ -183,8 +175,8 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C } tlsConf = tlsConf.Clone() tlsConf.MinVersion = tls.VersionTLS13 - - return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true) + setTLSConfigServerName(tlsConf, addr, host) + return dial(ctx, newSendConn(t.conn, addr, packetInfo{}, utils.DefaultLogger), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, use0RTT) } func (t *Transport) init(allowZeroLengthConnIDs bool) error { @@ -200,7 +192,6 @@ func (t *Transport) init(allowZeroLengthConnIDs bool) error { return } } - t.conn = conn t.logger = utils.DefaultLogger // TODO: make this configurable t.conn = conn @@ -234,7 +225,7 @@ func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) { if err := t.init(false); err != nil { return 0, err } - return t.conn.WritePacket(b, uint16(len(b)), addr, nil) + return t.conn.WritePacket(b, addr, nil) } func (t *Transport) enqueueClosePacket(p closePacket) { @@ -252,7 +243,7 @@ func (t *Transport) runSendQueue() { case <-t.listening: return case p := <-t.closeQueue: - t.conn.WritePacket(p.payload, uint16(len(p.payload)), p.addr, p.info.OOB()) + t.conn.WritePacket(p.payload, p.addr, p.info.OOB()) case p := <-t.statelessResetQueue: t.sendStatelessReset(p) } @@ -347,6 +338,13 @@ func (t *Transport) listen(conn rawConn) { } func (t *Transport) handlePacket(p receivedPacket) { + if len(p.data) == 0 { + return + } + if !wire.IsPotentialQUICPacket(p.data[0]) && !wire.IsLongHeaderPacket(p.data[0]) { + t.handleNonQUICPacket(p) + return + } connID, err := wire.ParseConnectionID(p.data, t.connIDLen) if err != nil { t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err) @@ -413,7 +411,7 @@ func (t *Transport) sendStatelessReset(p receivedPacket) { rand.Read(data) data[0] = (data[0] & 0x7f) | 0x40 data = append(data, token[:]...) - if _, err := t.conn.WritePacket(data, uint16(len(data)), p.remoteAddr, p.info.OOB()); err != nil { + if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { t.logger.Debugf("Error sending Stateless Reset to %s: %s", p.remoteAddr, err) } } @@ -435,3 +433,61 @@ func (t *Transport) maybeHandleStatelessReset(data []byte) bool { } return false } + +func (t *Transport) handleNonQUICPacket(p receivedPacket) { + // Strictly speaking, this is racy, + // but we only care about receiving packets at some point after ReadNonQUICPacket has been called. + if !t.readingNonQUICPackets.Load() { + return + } + select { + case t.nonQUICPackets <- p: + default: + if t.Tracer != nil { + t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) + } + } +} + +const maxQueuedNonQUICPackets = 32 + +// ReadNonQUICPacket reads non-QUIC packets received on the underlying connection. +// The detection logic is very simple: Any packet that has the first and second bit of the packet set to 0. +// Note that this is stricter than the detection logic defined in RFC 9443. +func (t *Transport) ReadNonQUICPacket(ctx context.Context, b []byte) (int, net.Addr, error) { + if err := t.init(false); err != nil { + return 0, nil, err + } + if !t.readingNonQUICPackets.Load() { + t.nonQUICPackets = make(chan receivedPacket, maxQueuedNonQUICPackets) + t.readingNonQUICPackets.Store(true) + } + select { + case <-ctx.Done(): + return 0, nil, ctx.Err() + case p := <-t.nonQUICPackets: + n := copy(b, p.data) + return n, p.remoteAddr, nil + case <-t.listening: + return 0, nil, errors.New("closed") + } +} + +func setTLSConfigServerName(tlsConf *tls.Config, addr net.Addr, host string) { + // If no ServerName is set, infer the ServerName from the host we're connecting to. + if tlsConf.ServerName != "" { + return + } + if host == "" { + if udpAddr, ok := addr.(*net.UDPAddr); ok { + tlsConf.ServerName = udpAddr.IP.String() + return + } + } + h, _, err := net.SplitHostPort(host) + if err != nil { // This happens if the host doesn't contain a port number. + tlsConf.ServerName = host + return + } + tlsConf.ServerName = h +} diff --git a/transport_test.go b/transport_test.go index 025e5322..58856218 100644 --- a/transport_test.go +++ b/transport_test.go @@ -2,6 +2,7 @@ package quic import ( "bytes" + "context" "crypto/rand" "errors" "net" @@ -15,9 +16,9 @@ import ( "github.com/refraction-networking/uquic/internal/wire" "github.com/refraction-networking/uquic/logging" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) var _ = Describe("Transport", func() { @@ -123,7 +124,7 @@ var _ = Describe("Transport", func() { tr.Close() }) - It("drops unparseable packets", func() { + It("drops unparseable QUIC packets", func() { addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} packetChan := make(chan packetToRead) tracer := mocklogging.NewMockTracer(mockCtrl) @@ -137,7 +138,7 @@ var _ = Describe("Transport", func() { tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped) }) packetChan <- packetToRead{ addr: addr, - data: []byte{0, 1, 2, 3}, + data: []byte{0x40 /* set the QUIC bit */, 1, 2, 3}, } Eventually(dropped).Should(BeClosed()) @@ -324,6 +325,90 @@ var _ = Describe("Transport", func() { conns := getMultiplexer().(*connMultiplexer).conns Expect(len(conns)).To(BeZero()) }) + + It("allows receiving non-QUIC packets", func() { + remoteAddr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} + packetChan := make(chan packetToRead) + tracer := mocklogging.NewMockTracer(mockCtrl) + tr := &Transport{ + Conn: newMockPacketConn(packetChan), + ConnectionIDLength: 10, + Tracer: tracer, + } + tr.init(true) + receivedPacketChan := make(chan []byte) + go func() { + defer GinkgoRecover() + b := make([]byte, 100) + n, addr, err := tr.ReadNonQUICPacket(context.Background(), b) + Expect(err).ToNot(HaveOccurred()) + Expect(addr).To(Equal(remoteAddr)) + receivedPacketChan <- b[:n] + }() + // Receiving of non-QUIC packets is enabled when ReadNonQUICPacket is called. + // Give the Go routine some time to spin up. + time.Sleep(scaleDuration(50 * time.Millisecond)) + packetChan <- packetToRead{ + addr: remoteAddr, + data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3}, + } + + Eventually(receivedPacketChan).Should(Receive(Equal([]byte{0, 1, 2, 3}))) + + // shutdown + close(packetChan) + tr.Close() + }) + + It("drops non-QUIC packet if the application doesn't process them quickly enough", func() { + remoteAddr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} + packetChan := make(chan packetToRead) + tracer := mocklogging.NewMockTracer(mockCtrl) + tr := &Transport{ + Conn: newMockPacketConn(packetChan), + ConnectionIDLength: 10, + Tracer: tracer, + } + tr.init(true) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, _, err := tr.ReadNonQUICPacket(ctx, make([]byte, 10)) + Expect(err).To(MatchError(context.Canceled)) + + for i := 0; i < maxQueuedNonQUICPackets; i++ { + packetChan <- packetToRead{ + addr: remoteAddr, + data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3}, + } + } + + done := make(chan struct{}) + tracer.EXPECT().DroppedPacket(remoteAddr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { + close(done) + }) + packetChan <- packetToRead{ + addr: remoteAddr, + data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3}, + } + Eventually(done).Should(BeClosed()) + + // shutdown + close(packetChan) + tr.Close() + }) + + remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 3, 5, 7), Port: 1234} + DescribeTable("setting the tls.Config.ServerName", + func(expected string, conf *tls.Config, addr net.Addr, host string) { + setTLSConfigServerName(conf, addr, host) + Expect(conf.ServerName).To(Equal(expected)) + }, + Entry("uses the value from the config", "foo.bar", &tls.Config{ServerName: "foo.bar"}, remoteAddr, "baz.foo"), + Entry("uses the hostname", "golang.org", &tls.Config{}, remoteAddr, "golang.org"), + Entry("removes the port from the hostname", "golang.org", &tls.Config{}, remoteAddr, "golang.org:1234"), + Entry("uses the IP", "1.3.5.7", &tls.Config{}, remoteAddr, ""), + ) }) type mockSyscallConn struct {