diff --git a/dnscrypt-proxy/proxy.go b/dnscrypt-proxy/proxy.go index 5c878202..291f6749 100644 --- a/dnscrypt-proxy/proxy.go +++ b/dnscrypt-proxy/proxy.go @@ -3,8 +3,6 @@ package main import ( crypto_rand "crypto/rand" "encoding/binary" - "io" - "io/ioutil" "net" "os" "sync/atomic" @@ -513,9 +511,9 @@ func (proxy *Proxy) processIncomingQuery(serverInfo *ServerInfo, clientProto str tid := TransactionID(query) SetTransactionID(query, 0) serverInfo.noticeBegin(proxy) - resp, _, err := proxy.xTransport.DoHQuery(serverInfo.useGet, serverInfo.URL, query, proxy.timeout) + serverResponse, tls, _, err := proxy.xTransport.DoHQuery(serverInfo.useGet, serverInfo.URL, query, proxy.timeout) SetTransactionID(query, tid) - if err == nil { + if err == nil || tls == nil || !tls.HandshakeComplete { response = nil } else if stale, ok := pluginsState.sessionData["stale"]; ok { dlog.Debug("Serving stale response") @@ -528,7 +526,7 @@ func (proxy *Proxy) processIncomingQuery(serverInfo *ServerInfo, clientProto str return } if response == nil { - response, err = ioutil.ReadAll(io.LimitReader(resp.Body, int64(MaxDNSPacketSize))) + response = serverResponse } if err != nil { pluginsState.returnCode = PluginsReturnCodeNetworkError diff --git a/dnscrypt-proxy/serversInfo.go b/dnscrypt-proxy/serversInfo.go index 264a944d..e8c38c4c 100644 --- a/dnscrypt-proxy/serversInfo.go +++ b/dnscrypt-proxy/serversInfo.go @@ -6,8 +6,6 @@ import ( "encoding/hex" "errors" "fmt" - "io" - "io/ioutil" "math/rand" "net" "net/url" @@ -381,18 +379,17 @@ func fetchDoHServerInfo(proxy *Proxy, name string, stamp stamps.ServerStamp, isN } body := dohTestPacket(0xcafe) useGet := false - if _, _, err := proxy.xTransport.DoHQuery(useGet, url, body, proxy.timeout); err != nil { + if _, _, _, err := proxy.xTransport.DoHQuery(useGet, url, body, proxy.timeout); err != nil { useGet = true - if _, _, err := proxy.xTransport.DoHQuery(useGet, url, body, proxy.timeout); err != nil { + if _, _, _, err := proxy.xTransport.DoHQuery(useGet, url, body, proxy.timeout); err != nil { return ServerInfo{}, err } dlog.Debugf("Server [%s] doesn't appear to support POST; falling back to GET requests", name) } - resp, rtt, err := proxy.xTransport.DoHQuery(useGet, url, body, proxy.timeout) + serverResponse, tls, rtt, err := proxy.xTransport.DoHQuery(useGet, url, body, proxy.timeout) if err != nil { return ServerInfo{}, err } - tls := resp.TLS if tls == nil || !tls.HandshakeComplete { return ServerInfo{}, errors.New("TLS handshake failed") } @@ -428,7 +425,7 @@ func fetchDoHServerInfo(proxy *Proxy, name string, stamp stamps.ServerStamp, isN if !found && len(stamp.Hashes) > 0 { return ServerInfo{}, fmt.Errorf("Certificate hash [%x] not found for [%s]", wantedHash, name) } - respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, MaxHTTPBodyLength)) + respBody := serverResponse if err != nil { return ServerInfo{}, err } diff --git a/dnscrypt-proxy/sources.go b/dnscrypt-proxy/sources.go index b36e611a..2f8b2b72 100644 --- a/dnscrypt-proxy/sources.go +++ b/dnscrypt-proxy/sources.go @@ -3,9 +3,7 @@ package main import ( "bytes" "fmt" - "io" "io/ioutil" - "net/http" "net/url" "os" "path/filepath" @@ -132,12 +130,8 @@ func (source *Source) parseURLs(urls []string) { } func fetchFromURL(xTransport *XTransport, u *url.URL) (bin []byte, err error) { - var resp *http.Response - if resp, _, err = xTransport.Get(u, "", DefaultTimeout); err == nil { - bin, err = ioutil.ReadAll(io.LimitReader(resp.Body, MaxHTTPBodyLength)) - resp.Body.Close() - } - return + bin, _, _, err = xTransport.Get(u, "", DefaultTimeout) + return bin, err } func (source *Source) fetchWithCache(xTransport *XTransport, now time.Time) (delay time.Duration, err error) { diff --git a/dnscrypt-proxy/xtransport.go b/dnscrypt-proxy/xtransport.go index 4ffd0db2..927ab9d5 100644 --- a/dnscrypt-proxy/xtransport.go +++ b/dnscrypt-proxy/xtransport.go @@ -8,6 +8,7 @@ import ( "encoding/base64" "encoding/hex" "errors" + "io" "io/ioutil" "math/rand" "net" @@ -316,7 +317,7 @@ func (xTransport *XTransport) resolveAndUpdateCache(host string) error { return nil } -func (xTransport *XTransport) Fetch(method string, url *url.URL, accept string, contentType string, body *[]byte, timeout time.Duration) (*http.Response, time.Duration, error) { +func (xTransport *XTransport) Fetch(method string, url *url.URL, accept string, contentType string, body *[]byte, timeout time.Duration) ([]byte, *tls.ConnectionState, time.Duration, error) { if timeout <= 0 { timeout = xTransport.timeout } @@ -338,11 +339,11 @@ func (xTransport *XTransport) Fetch(method string, url *url.URL, accept string, } host, _ := ExtractHostAndPort(url.Host, 0) if xTransport.proxyDialer == nil && strings.HasSuffix(host, ".onion") { - return nil, 0, errors.New("Onion service is not reachable without Tor") + return nil, nil, 0, errors.New("Onion service is not reachable without Tor") } if err := xTransport.resolveAndUpdateCache(host); err != nil { dlog.Errorf("Unable to resolve [%v] - Make sure that the system resolver works, or that `fallback_resolver` has been set to a resolver that can be reached", host) - return nil, 0, err + return nil, nil, 0, err } req := &http.Request{ Method: method, @@ -373,19 +374,26 @@ func (xTransport *XTransport) Fetch(method string, url *url.URL, accept string, xTransport.tlsCipherSuite = nil xTransport.rebuildTransport() } + return nil, nil, 0, err } - return resp, rtt, err + tls := resp.TLS + bin, err := ioutil.ReadAll(io.LimitReader(resp.Body, MaxHTTPBodyLength)) + if err != nil { + return nil, tls, 0, err + } + resp.Body.Close() + return bin, tls, rtt, err } -func (xTransport *XTransport) Get(url *url.URL, accept string, timeout time.Duration) (*http.Response, time.Duration, error) { +func (xTransport *XTransport) Get(url *url.URL, accept string, timeout time.Duration) ([]byte, *tls.ConnectionState, time.Duration, error) { return xTransport.Fetch("GET", url, accept, "", nil, timeout) } -func (xTransport *XTransport) Post(url *url.URL, accept string, contentType string, body *[]byte, timeout time.Duration) (*http.Response, time.Duration, error) { +func (xTransport *XTransport) Post(url *url.URL, accept string, contentType string, body *[]byte, timeout time.Duration) ([]byte, *tls.ConnectionState, time.Duration, error) { return xTransport.Fetch("POST", url, accept, contentType, body, timeout) } -func (xTransport *XTransport) DoHQuery(useGet bool, url *url.URL, body []byte, timeout time.Duration) (*http.Response, time.Duration, error) { +func (xTransport *XTransport) DoHQuery(useGet bool, url *url.URL, body []byte, timeout time.Duration) ([]byte, *tls.ConnectionState, time.Duration, error) { dataType := "application/dns-message" if useGet { qs := url.Query()