mirror of
https://github.com/redlib-org/redlib.git
synced 2025-04-04 13:37:40 +03:00
fix(client): use boolean flag
This commit is contained in:
parent
edb16f29ce
commit
9e2dc50d6e
3 changed files with 27 additions and 17 deletions
|
@ -15,6 +15,7 @@ use serde_json::Value;
|
|||
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU16};
|
||||
use std::sync::OnceLock;
|
||||
use std::{io, result::Result};
|
||||
|
||||
use crate::dbg_msg;
|
||||
|
@ -25,14 +26,28 @@ use crate::utils::format_url;
|
|||
const REDDIT_URL_BASE: &str = "https://oauth.reddit.com";
|
||||
const ALTERNATIVE_REDDIT_URL_BASE: &str = "https://www.reddit.com";
|
||||
|
||||
pub static CLIENT: Lazy<Client<HttpsConnector<HttpConnector>>> = Lazy::new(|| {
|
||||
pub static CLIENT: OnceLock<Client<HttpsConnector<HttpConnector>>> = OnceLock::new();
|
||||
|
||||
pub static OAUTH_CLIENT: Lazy<ArcSwap<Oauth>> = Lazy::new(|| {
|
||||
let client = block_on(Oauth::new());
|
||||
tokio::spawn(token_daemon());
|
||||
ArcSwap::new(client.into())
|
||||
});
|
||||
|
||||
pub static OAUTH_RATELIMIT_REMAINING: AtomicU16 = AtomicU16::new(99);
|
||||
|
||||
pub static OAUTH_IS_ROLLING_OVER: AtomicBool = AtomicBool::new(false);
|
||||
|
||||
/// Generate a client given a flag.
|
||||
#[allow(unused_variables)]
|
||||
pub fn generate_client(no_https_verification: bool) -> Client<HttpsConnector<HttpConnector>> {
|
||||
// Use native certificates to verify requests to reddit
|
||||
#[cfg(not(feature = "no-https-verification"))]
|
||||
let https = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots().https_only().enable_http1().build();
|
||||
|
||||
// If https verification is disabled for debug purposes, create a custom ClientConfig
|
||||
#[cfg(feature = "no-https-verification")]
|
||||
let https = {
|
||||
let https = if no_https_verification {
|
||||
use rustls::ClientConfig;
|
||||
use std::sync::Arc;
|
||||
|
||||
|
@ -60,19 +75,11 @@ pub static CLIENT: Lazy<Client<HttpsConnector<HttpConnector>>> = Lazy::new(|| {
|
|||
|
||||
config.dangerous().set_certificate_verifier(Arc::new(NoCertificateVerification));
|
||||
hyper_rustls::HttpsConnectorBuilder::new().with_tls_config(config).https_only().enable_http1().build()
|
||||
} else {
|
||||
hyper_rustls::HttpsConnectorBuilder::new().with_native_roots().https_only().enable_http1().build()
|
||||
};
|
||||
client::Client::builder().build(https)
|
||||
});
|
||||
|
||||
pub static OAUTH_CLIENT: Lazy<ArcSwap<Oauth>> = Lazy::new(|| {
|
||||
let client = block_on(Oauth::new());
|
||||
tokio::spawn(token_daemon());
|
||||
ArcSwap::new(client.into())
|
||||
});
|
||||
|
||||
pub static OAUTH_RATELIMIT_REMAINING: AtomicU16 = AtomicU16::new(99);
|
||||
|
||||
pub static OAUTH_IS_ROLLING_OVER: AtomicBool = AtomicBool::new(false);
|
||||
}
|
||||
|
||||
/// Gets the canonical path for a resource on Reddit. This is accomplished by
|
||||
/// making a `HEAD` request to Reddit at the path given in `path`.
|
||||
|
@ -155,7 +162,7 @@ async fn stream(url: &str, req: &Request<Body>) -> Result<Response<Body>, String
|
|||
let parsed_uri = url.parse::<Uri>().map_err(|_| "Couldn't parse URL".to_string())?;
|
||||
|
||||
// Build the hyper client from the HTTPS connector.
|
||||
let client: Client<_, Body> = CLIENT.clone();
|
||||
let client: Client<_, Body> = CLIENT.get().unwrap().clone();
|
||||
|
||||
let mut builder = Request::get(parsed_uri);
|
||||
|
||||
|
@ -211,7 +218,7 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
|
|||
let url = format!("{REDDIT_URL_BASE}{path}");
|
||||
|
||||
// Construct the hyper client from the HTTPS connector.
|
||||
let client: Client<_, Body> = CLIENT.clone();
|
||||
let client: Client<_, Body> = CLIENT.get().unwrap().clone();
|
||||
|
||||
let (token, vendor_id, device_id, user_agent, loid) = {
|
||||
let client = OAUTH_CLIENT.load_full();
|
||||
|
|
|
@ -22,7 +22,7 @@ use futures_lite::FutureExt;
|
|||
use hyper::{header::HeaderValue, Body, Request, Response};
|
||||
|
||||
mod client;
|
||||
use client::{canonical_path, proxy};
|
||||
use client::{canonical_path, generate_client, proxy, CLIENT};
|
||||
use log::info;
|
||||
use once_cell::sync::Lazy;
|
||||
use server::RequestExt;
|
||||
|
@ -179,6 +179,7 @@ async fn main() {
|
|||
let address = matches.get_one::<String>("address").unwrap();
|
||||
let port = matches.get_one::<String>("port").unwrap();
|
||||
let hsts = matches.get_one("hsts").map(|m: &String| m.as_str());
|
||||
let no_https_verification = matches.contains_id("no-https-verification");
|
||||
|
||||
let listener = [address, ":", port].concat();
|
||||
|
||||
|
@ -199,6 +200,8 @@ async fn main() {
|
|||
Lazy::force(&instance_info::INSTANCE_INFO);
|
||||
info!("Creating OAUTH client.");
|
||||
Lazy::force(&OAUTH_CLIENT);
|
||||
info!("Creating HTTP client.");
|
||||
CLIENT.set(generate_client(no_https_verification)).unwrap();
|
||||
|
||||
// Define default headers (added to all responses)
|
||||
app.default_headers = headers! {
|
||||
|
|
|
@ -70,7 +70,7 @@ impl Oauth {
|
|||
let request = builder.body(body).unwrap();
|
||||
|
||||
// Send request
|
||||
let client: client::Client<_, Body> = CLIENT.clone();
|
||||
let client: client::Client<_, Body> = CLIENT.get().unwrap().clone();
|
||||
let resp = client.request(request).await.ok()?;
|
||||
|
||||
// Parse headers - loid header _should_ be saved sent on subsequent token refreshes.
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue