uquic/http3/u_roundtrip.go
Gaukas Wang 95575f5fe7
break: update repo url [ci skip]
uTLS is not yet bumped to the new version, so this commit breaks the dependencies relationship by getting rid of the local replace.
2023-08-03 18:58:52 -06:00

192 lines
4.8 KiB
Go

package http3
import (
"context"
"errors"
"fmt"
"net"
"net/http"
quic "github.com/refraction-networking/uquic"
tls "github.com/refraction-networking/utls"
"golang.org/x/net/http/httpguts"
)
// URoundTripper is an HTTP/3 round tripper that embeds a RoundTripper and,
// when no custom Dial function is set, dials QUIC connections through a
// uquic UTransport instead of the embedded RoundTripper's plain transport.
type URoundTripper struct {
	*RoundTripper
	// quicSpec is attached to the UTransport that getClient creates on demand
	// when neither transport exists yet.
	quicSpec *quic.QUICSpec
	// uTransportOverride, when non-nil, is preferred by the dialer returned
	// from makeDialer. It is either supplied to GetURoundTripper or created
	// lazily by getClient, and it is closed (together with its Conn) by Close.
	uTransportOverride *quic.UTransport
}
// GetURoundTripper wraps r in a URoundTripper that dials using QUICSpec.
//
// Before wrapping, QUICSpec.UpdateConfig is called with r.QuicConfig; the
// name suggests it adjusts that config in place (NOTE(review): confirm, and
// confirm it tolerates a nil r.QuicConfig).
//
// uTransport may be nil. When non-nil and r.Dial is nil, it is used for
// dialing instead of a lazily created transport. Either way, a non-nil
// override is closed by URoundTripper.Close.
func GetURoundTripper(r *RoundTripper, QUICSpec *quic.QUICSpec, uTransport *quic.UTransport) *URoundTripper {
	// Let the spec adjust the QUIC config before any client is built from it.
	QUICSpec.UpdateConfig(r.QuicConfig)

	urt := new(URoundTripper)
	urt.RoundTripper = r
	urt.quicSpec = QUICSpec
	urt.uTransportOverride = uTransport
	return urt
}
// RoundTripOpt is like RoundTrip, but takes options.
//
// The request is validated first:
//   - it must have a non-nil URL with the "https" scheme and a host;
//   - it must have a non-nil, well-formed header map;
//   - its method must be valid, if one is set.
//
// On any validation failure the request body is closed and an error is
// returned without dialing.
//
// Otherwise a client for the request's authority is obtained from getClient
// (only an already-cached one when opt.OnlyCachedConn is set). The request is
// sent through that client, and the client's use count is held for the
// duration of the call.
//
// If the round trip fails, the cached client for that authority is removed.
// If the failed client had already completed its handshake and the error is
// (or wraps) a net.Error timeout, the request is retried on a newly obtained
// client. NOTE(review): the retry resends req as-is; a body already consumed
// by the failed attempt is not rewound here.
func (r *URoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
	if req.URL == nil {
		closeRequestBody(req)
		return nil, errors.New("http3: nil Request.URL")
	}
	if req.URL.Scheme != "https" {
		closeRequestBody(req)
		return nil, fmt.Errorf("http3: unsupported protocol scheme: %s", req.URL.Scheme)
	}
	if req.URL.Host == "" {
		closeRequestBody(req)
		return nil, errors.New("http3: no Host in request URL")
	}
	if req.Header == nil {
		closeRequestBody(req)
		return nil, errors.New("http3: nil Request.Header")
	}
	for k, vv := range req.Header {
		if !httpguts.ValidHeaderFieldName(k) {
			// Close the body here too, like every other validation failure:
			// http.RoundTripper requires the body to be closed even on error.
			closeRequestBody(req)
			return nil, fmt.Errorf("http3: invalid http header field name %q", k)
		}
		for _, v := range vv {
			if !httpguts.ValidHeaderFieldValue(v) {
				closeRequestBody(req)
				return nil, fmt.Errorf("http3: invalid http header field value %q for key %v", v, k)
			}
		}
	}
	if req.Method != "" && !validMethod(req.Method) {
		closeRequestBody(req)
		return nil, fmt.Errorf("http3: invalid method %q", req.Method)
	}

	hostname := authorityAddr("https", hostnameFromRequest(req))
	cl, isReused, err := r.getClient(hostname, opt.OnlyCachedConn)
	if err != nil {
		return nil, err
	}
	defer cl.useCount.Add(-1)

	rsp, err := cl.RoundTripOpt(req, opt)
	if err != nil {
		r.removeClient(hostname)
		if isReused {
			// errors.As also matches timeouts wrapped by lower layers, which a
			// bare type assertion would miss.
			var nerr net.Error
			if errors.As(err, &nerr) && nerr.Timeout() {
				return r.RoundTripOpt(req, opt)
			}
		}
	}
	return rsp, err
}
// RoundTrip implements http.RoundTripper by delegating to RoundTripOpt with
// zero-valued options.
func (r *URoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	var defaults RoundTripOpt
	return r.RoundTripOpt(req, defaults)
}
// getClient returns the client cached for hostname, creating and caching a new
// one when none exists. If onlyCached is set and nothing is cached, it returns
// ErrNoCachedConn instead.
//
// isReused reports whether a cached client was returned whose handshake had
// already completed. The returned client's useCount is incremented; the caller
// is expected to decrement it when finished (as RoundTripOpt does).
//
// When r.Dial is nil, the dialer comes from makeDialer. If neither the
// embedded transport nor uTransportOverride exists yet, a UDP socket on a
// system-chosen local address is opened and wrapped in a UTransport carrying
// r.quicSpec. All of this happens while holding r.mutex.
func (r *URoundTripper) getClient(hostname string, onlyCached bool) (rtc *roundTripCloserWithCount, isReused bool, err error) {
	r.mutex.Lock()
	defer r.mutex.Unlock()

	if r.clients == nil {
		r.clients = make(map[string]*roundTripCloserWithCount)
	}

	// Fast path: a client for this authority is already cached.
	if cached, found := r.clients[hostname]; found {
		reused := cached.HandshakeComplete()
		cached.useCount.Add(1)
		return cached, reused, nil
	}
	if onlyCached {
		return nil, false, ErrNoCachedConn
	}

	build := newClient
	if r.newClient != nil {
		build = r.newClient
	}

	dialer := r.Dial
	if dialer == nil {
		if r.uTransportOverride == nil && r.transport == nil {
			conn, listenErr := net.ListenUDP("udp", nil)
			if listenErr != nil {
				return nil, false, listenErr
			}
			r.uTransportOverride = &quic.UTransport{
				Transport: &quic.Transport{Conn: conn},
				QUICSpec:  r.quicSpec,
			}
		}
		dialer = r.makeDialer()
	}

	opts := &roundTripperOpts{
		EnableDatagram:     r.EnableDatagrams,
		DisableCompression: r.DisableCompression,
		MaxHeaderBytes:     r.MaxResponseHeaderBytes,
		StreamHijacker:     r.StreamHijacker,
		UniStreamHijacker:  r.UniStreamHijacker,
	}
	rt, err := build(hostname, r.TLSClientConfig, opts, r.QuicConfig, dialer)
	if err != nil {
		return nil, false, err
	}

	fresh := &roundTripCloserWithCount{roundTripCloser: rt}
	r.clients[hostname] = fresh
	fresh.useCount.Add(1)
	return fresh, false, nil
}
// Close closes every cached client, then each QUIC transport held by the
// round tripper (the embedded RoundTripper's transport and
// uTransportOverride) together with its underlying packet connection.
//
// Every resource is closed even if an earlier one fails, so a single failing
// client cannot leak the remaining clients or the UDP sockets. The first
// error encountered is returned. Afterwards the client cache and both
// transport fields are cleared.
//
// NOTE(review): uTransportOverride may have been supplied by the caller via
// GetURoundTripper; it is closed here regardless of who created it.
func (r *URoundTripper) Close() error {
	r.mutex.Lock()
	defer r.mutex.Unlock()

	var firstErr error
	// record keeps the first non-nil error and ignores later ones.
	record := func(err error) {
		if err != nil && firstErr == nil {
			firstErr = err
		}
	}

	for _, client := range r.clients {
		record(client.Close())
	}
	r.clients = nil

	if r.transport != nil {
		record(r.transport.Close())
		// The transport was built around an externally supplied Conn, so the
		// socket is closed explicitly.
		record(r.transport.Conn.Close())
		r.transport = nil
	}
	if r.uTransportOverride != nil {
		record(r.uTransportOverride.Close())
		record(r.uTransportOverride.Conn.Close())
		r.uTransportOverride = nil
	}
	return firstErr
}
// makeDialer returns a QUIC dial function. The function dials through
// r.uTransportOverride when it is set, and otherwise through the embedded
// RoundTripper's r.transport. If neither is set, it returns an error.
//
// The transport fields are read each time the returned function runs, not
// when makeDialer is called. NOTE(review): those reads are not guarded by
// r.mutex, so a concurrent Close can still race with an in-flight dial.
func (r *URoundTripper) makeDialer() func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
	return func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
		// Read each field once so the nil check and the DialEarly call see the
		// same value; re-reading could observe a nil set by Close in between.
		uTransport := r.uTransportOverride
		transport := r.transport
		if uTransport == nil && transport == nil {
			// Fail fast without a pointless address resolution.
			return nil, errors.New("http3: no QUIC transport available")
		}

		udpAddr, err := net.ResolveUDPAddr("udp", addr)
		if err != nil {
			return nil, err
		}
		if uTransport != nil {
			return uTransport.DialEarly(ctx, udpAddr, tlsCfg, cfg)
		}
		return transport.DialEarly(ctx, udpAddr, tlsCfg, cfg)
	}
}