Make the xTransport functions return the HTTP body directly

This simplifies things, but also make RTT computation way more reliable
This commit is contained in:
Frank Denis 2020-02-21 22:33:34 +01:00
parent a6d946c41f
commit aa0e7f42d3
4 changed files with 24 additions and 27 deletions

View file

@ -3,8 +3,6 @@ package main
import ( import (
crypto_rand "crypto/rand" crypto_rand "crypto/rand"
"encoding/binary" "encoding/binary"
"io"
"io/ioutil"
"net" "net"
"os" "os"
"sync/atomic" "sync/atomic"
@ -513,9 +511,9 @@ func (proxy *Proxy) processIncomingQuery(serverInfo *ServerInfo, clientProto str
tid := TransactionID(query) tid := TransactionID(query)
SetTransactionID(query, 0) SetTransactionID(query, 0)
serverInfo.noticeBegin(proxy) serverInfo.noticeBegin(proxy)
resp, _, err := proxy.xTransport.DoHQuery(serverInfo.useGet, serverInfo.URL, query, proxy.timeout) serverResponse, tls, _, err := proxy.xTransport.DoHQuery(serverInfo.useGet, serverInfo.URL, query, proxy.timeout)
SetTransactionID(query, tid) SetTransactionID(query, tid)
if err == nil { if err == nil || tls == nil || !tls.HandshakeComplete {
response = nil response = nil
} else if stale, ok := pluginsState.sessionData["stale"]; ok { } else if stale, ok := pluginsState.sessionData["stale"]; ok {
dlog.Debug("Serving stale response") dlog.Debug("Serving stale response")
@ -528,7 +526,7 @@ func (proxy *Proxy) processIncomingQuery(serverInfo *ServerInfo, clientProto str
return return
} }
if response == nil { if response == nil {
response, err = ioutil.ReadAll(io.LimitReader(resp.Body, int64(MaxDNSPacketSize))) response = serverResponse
} }
if err != nil { if err != nil {
pluginsState.returnCode = PluginsReturnCodeNetworkError pluginsState.returnCode = PluginsReturnCodeNetworkError

View file

@ -6,8 +6,6 @@ import (
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
"io"
"io/ioutil"
"math/rand" "math/rand"
"net" "net"
"net/url" "net/url"
@ -381,18 +379,17 @@ func fetchDoHServerInfo(proxy *Proxy, name string, stamp stamps.ServerStamp, isN
} }
body := dohTestPacket(0xcafe) body := dohTestPacket(0xcafe)
useGet := false useGet := false
if _, _, err := proxy.xTransport.DoHQuery(useGet, url, body, proxy.timeout); err != nil { if _, _, _, err := proxy.xTransport.DoHQuery(useGet, url, body, proxy.timeout); err != nil {
useGet = true useGet = true
if _, _, err := proxy.xTransport.DoHQuery(useGet, url, body, proxy.timeout); err != nil { if _, _, _, err := proxy.xTransport.DoHQuery(useGet, url, body, proxy.timeout); err != nil {
return ServerInfo{}, err return ServerInfo{}, err
} }
dlog.Debugf("Server [%s] doesn't appear to support POST; falling back to GET requests", name) dlog.Debugf("Server [%s] doesn't appear to support POST; falling back to GET requests", name)
} }
resp, rtt, err := proxy.xTransport.DoHQuery(useGet, url, body, proxy.timeout) serverResponse, tls, rtt, err := proxy.xTransport.DoHQuery(useGet, url, body, proxy.timeout)
if err != nil { if err != nil {
return ServerInfo{}, err return ServerInfo{}, err
} }
tls := resp.TLS
if tls == nil || !tls.HandshakeComplete { if tls == nil || !tls.HandshakeComplete {
return ServerInfo{}, errors.New("TLS handshake failed") return ServerInfo{}, errors.New("TLS handshake failed")
} }
@ -428,7 +425,7 @@ func fetchDoHServerInfo(proxy *Proxy, name string, stamp stamps.ServerStamp, isN
if !found && len(stamp.Hashes) > 0 { if !found && len(stamp.Hashes) > 0 {
return ServerInfo{}, fmt.Errorf("Certificate hash [%x] not found for [%s]", wantedHash, name) return ServerInfo{}, fmt.Errorf("Certificate hash [%x] not found for [%s]", wantedHash, name)
} }
respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, MaxHTTPBodyLength)) respBody := serverResponse
if err != nil { if err != nil {
return ServerInfo{}, err return ServerInfo{}, err
} }

View file

@ -3,9 +3,7 @@ package main
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"io"
"io/ioutil" "io/ioutil"
"net/http"
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
@ -132,12 +130,8 @@ func (source *Source) parseURLs(urls []string) {
} }
func fetchFromURL(xTransport *XTransport, u *url.URL) (bin []byte, err error) { func fetchFromURL(xTransport *XTransport, u *url.URL) (bin []byte, err error) {
var resp *http.Response bin, _, _, err = xTransport.Get(u, "", DefaultTimeout)
if resp, _, err = xTransport.Get(u, "", DefaultTimeout); err == nil { return bin, err
bin, err = ioutil.ReadAll(io.LimitReader(resp.Body, MaxHTTPBodyLength))
resp.Body.Close()
}
return
} }
func (source *Source) fetchWithCache(xTransport *XTransport, now time.Time) (delay time.Duration, err error) { func (source *Source) fetchWithCache(xTransport *XTransport, now time.Time) (delay time.Duration, err error) {

View file

@ -8,6 +8,7 @@ import (
"encoding/base64" "encoding/base64"
"encoding/hex" "encoding/hex"
"errors" "errors"
"io"
"io/ioutil" "io/ioutil"
"math/rand" "math/rand"
"net" "net"
@ -316,7 +317,7 @@ func (xTransport *XTransport) resolveAndUpdateCache(host string) error {
return nil return nil
} }
func (xTransport *XTransport) Fetch(method string, url *url.URL, accept string, contentType string, body *[]byte, timeout time.Duration) (*http.Response, time.Duration, error) { func (xTransport *XTransport) Fetch(method string, url *url.URL, accept string, contentType string, body *[]byte, timeout time.Duration) ([]byte, *tls.ConnectionState, time.Duration, error) {
if timeout <= 0 { if timeout <= 0 {
timeout = xTransport.timeout timeout = xTransport.timeout
} }
@ -338,11 +339,11 @@ func (xTransport *XTransport) Fetch(method string, url *url.URL, accept string,
} }
host, _ := ExtractHostAndPort(url.Host, 0) host, _ := ExtractHostAndPort(url.Host, 0)
if xTransport.proxyDialer == nil && strings.HasSuffix(host, ".onion") { if xTransport.proxyDialer == nil && strings.HasSuffix(host, ".onion") {
return nil, 0, errors.New("Onion service is not reachable without Tor") return nil, nil, 0, errors.New("Onion service is not reachable without Tor")
} }
if err := xTransport.resolveAndUpdateCache(host); err != nil { if err := xTransport.resolveAndUpdateCache(host); err != nil {
dlog.Errorf("Unable to resolve [%v] - Make sure that the system resolver works, or that `fallback_resolver` has been set to a resolver that can be reached", host) dlog.Errorf("Unable to resolve [%v] - Make sure that the system resolver works, or that `fallback_resolver` has been set to a resolver that can be reached", host)
return nil, 0, err return nil, nil, 0, err
} }
req := &http.Request{ req := &http.Request{
Method: method, Method: method,
@ -373,19 +374,26 @@ func (xTransport *XTransport) Fetch(method string, url *url.URL, accept string,
xTransport.tlsCipherSuite = nil xTransport.tlsCipherSuite = nil
xTransport.rebuildTransport() xTransport.rebuildTransport()
} }
return nil, nil, 0, err
} }
return resp, rtt, err tls := resp.TLS
bin, err := ioutil.ReadAll(io.LimitReader(resp.Body, MaxHTTPBodyLength))
if err != nil {
return nil, tls, 0, err
}
resp.Body.Close()
return bin, tls, rtt, err
} }
func (xTransport *XTransport) Get(url *url.URL, accept string, timeout time.Duration) (*http.Response, time.Duration, error) { func (xTransport *XTransport) Get(url *url.URL, accept string, timeout time.Duration) ([]byte, *tls.ConnectionState, time.Duration, error) {
return xTransport.Fetch("GET", url, accept, "", nil, timeout) return xTransport.Fetch("GET", url, accept, "", nil, timeout)
} }
func (xTransport *XTransport) Post(url *url.URL, accept string, contentType string, body *[]byte, timeout time.Duration) (*http.Response, time.Duration, error) { func (xTransport *XTransport) Post(url *url.URL, accept string, contentType string, body *[]byte, timeout time.Duration) ([]byte, *tls.ConnectionState, time.Duration, error) {
return xTransport.Fetch("POST", url, accept, contentType, body, timeout) return xTransport.Fetch("POST", url, accept, contentType, body, timeout)
} }
func (xTransport *XTransport) DoHQuery(useGet bool, url *url.URL, body []byte, timeout time.Duration) (*http.Response, time.Duration, error) { func (xTransport *XTransport) DoHQuery(useGet bool, url *url.URL, body []byte, timeout time.Duration) ([]byte, *tls.ConnectionState, time.Duration, error) {
dataType := "application/dns-message" dataType := "application/dns-message"
if useGet { if useGet {
qs := url.Query() qs := url.Query()