Merge remote-tracking branch 'origin/pull/819'

This commit is contained in:
Matthew Esposito 2023-12-26 15:48:27 -05:00
commit 90d1831352
No known key found for this signature in database
8 changed files with 402 additions and 45 deletions

View file

@ -1,4 +1,5 @@
use cached::proc_macro::cached;
use futures_lite::future::block_on;
use futures_lite::{future::Boxed, FutureExt};
use hyper::client::HttpConnector;
use hyper::{body, body::Buf, client, header, Body, Client, Method, Request, Response, Uri};
@ -7,19 +8,28 @@ use libflate::gzip;
use once_cell::sync::Lazy;
use percent_encoding::{percent_encode, CONTROLS};
use serde_json::Value;
use std::{io, result::Result, sync::atomic::Ordering::SeqCst};
use crate::instance_info::INSTANCE_INFO;
use std::{io, result::Result};
use tokio::sync::RwLock;
use crate::dbg_msg;
use crate::oauth::{token_daemon, Oauth};
use crate::server::RequestExt;
use crate::{config, dbg_msg};
const REDDIT_URL_BASE: &str = "https://www.reddit.com";
const REDDIT_URL_BASE: &str = "https://oauth.reddit.com";
static CLIENT: Lazy<Client<HttpsConnector<HttpConnector>>> = Lazy::new(|| {
pub(crate) static CLIENT: Lazy<Client<HttpsConnector<HttpConnector>>> = Lazy::new(|| {
let https = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots().https_only().enable_http1().build();
client::Client::builder().build(https)
});
pub(crate) static OAUTH_CLIENT: Lazy<RwLock<Oauth>> = Lazy::new(|| {
let client = block_on(Oauth::new());
tokio::spawn(token_daemon());
RwLock::new(client)
});
/// Gets the canonical path for a resource on Reddit. This is accomplished by
/// making a `HEAD` request to Reddit at the path given in `path`.
///
@ -135,14 +145,27 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
// Construct the hyper client from the HTTPS connector.
let client: client::Client<_, hyper::Body> = CLIENT.clone();
let (token, vendor_id, device_id, user_agent, loid) = {
let client = block_on(OAUTH_CLIENT.read());
(
client.token.clone(),
client.headers_map.get("Client-Vendor-Id").cloned().unwrap_or_default(),
client.headers_map.get("X-Reddit-Device-Id").cloned().unwrap_or_default(),
client.headers_map.get("User-Agent").cloned().unwrap_or_default(),
client.headers_map.get("x-reddit-loid").cloned().unwrap_or_default(),
)
};
// Build request to Reddit. When making a GET, request gzip compression.
// (Reddit doesn't do brotli yet.)
let builder = Request::builder()
.method(method)
.uri(&url)
.header("User-Agent", concat!("web:libreddit:", env!("CARGO_PKG_VERSION")))
.header("Host", "www.reddit.com")
.header("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8")
.header("User-Agent", user_agent)
.header("Client-Vendor-Id", vendor_id)
.header("X-Reddit-Device-Id", device_id)
.header("x-reddit-loid", loid)
.header("Host", "oauth.reddit.com")
.header("Authorization", &format!("Bearer {}", token))
.header("Accept-Encoding", if method == Method::GET { "gzip" } else { "identity" })
.header("Accept-Language", "en-US,en;q=0.5")
.header("Connection", "keep-alive")