From 100a7b65a6a79968cbc8548f4ce12d722dfb0cba Mon Sep 17 00:00:00 2001 From: Matthew Esposito Date: Sat, 23 Nov 2024 21:17:52 -0500 Subject: [PATCH] fix(client): update headers management, add self check (fix #334, fix #318) --- src/client.rs | 41 ++++++++++++++++++++++++++++++++++++++--- src/main.rs | 12 +++++++++++- src/oauth.rs | 6 ++++-- 3 files changed, 53 insertions(+), 6 deletions(-) diff --git a/src/client.rs b/src/client.rs index 1e1661d..0e2c301 100644 --- a/src/client.rs +++ b/src/client.rs @@ -19,6 +19,7 @@ use std::{io, result::Result}; use crate::dbg_msg; use crate::oauth::{force_refresh_token, token_daemon, Oauth}; use crate::server::RequestExt; +use crate::subreddit::community; use crate::utils::format_url; const REDDIT_URL_BASE: &str = "https://oauth.reddit.com"; @@ -235,13 +236,11 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo { let client = OAUTH_CLIENT.load_full(); - for (key, value) in client.initial_headers.clone() { + for (key, value) in client.headers_map.clone() { headers.push((key, value)); } } - trace!("Headers: {:#?}", headers); - // shuffle headers: https://github.com/redlib-org/redlib/issues/324 fastrand::shuffle(&mut headers); @@ -390,6 +389,12 @@ pub async fn json(path: String, quarantine: bool) -> Result { "Ratelimit remaining: Header says {remaining}, we have {current_rate_limit}. Resets in {reset}. Rollover: {}. Ratelimit used: {used}", if is_rolling_over { "yes" } else { "no" }, ); + + // If can parse remaining as a float, round to a u16 and save + if let Ok(val) = remaining.parse::() { + OAUTH_RATELIMIT_REMAINING.store(val.round() as u16, Ordering::SeqCst); + } + Some(reset) } else { None @@ -474,6 +479,36 @@ pub async fn json(path: String, quarantine: bool) -> Result { } } +async fn self_check(sub: &str) -> Result<(), String> { + let request = Request::get(format!("/r/{sub}/")).body(Body::empty()).unwrap(); + + match community(request).await { + Ok(sub) if sub.status().is_success() => Ok(()), + Ok(sub) => Err(sub.status().to_string()), + Err(e) => Err(e), + } +} + +pub async fn rate_limit_check() -> Result<(), String> { + // First, check a subreddit. + self_check("reddit").await?; + // This will reduce the rate limit to 99. Assert this check. + if OAUTH_RATELIMIT_REMAINING.load(Ordering::SeqCst) != 99 { + return Err(format!("Rate limit check failed: expected 99, got {}", OAUTH_RATELIMIT_REMAINING.load(Ordering::SeqCst))); + } + // Now, we switch out the OAuth client. + // This checks for the IP rate limit association. + force_refresh_token().await; + // Now, check a new sub to break cache. + self_check("rust").await?; + // Again, assert the rate limit check. + if OAUTH_RATELIMIT_REMAINING.load(Ordering::SeqCst) != 99 { + return Err(format!("Rate limit check failed: expected 99, got {}", OAUTH_RATELIMIT_REMAINING.load(Ordering::SeqCst))); + } + + Ok(()) +} + #[cfg(test)] static POPULAR_URL: &str = "/r/popular/hot.json?&raw_json=1&geo_filter=GLOBAL"; diff --git a/src/main.rs b/src/main.rs index 7342597..abae968 100644 --- a/src/main.rs +++ b/src/main.rs @@ -11,7 +11,7 @@ use hyper::Uri; use hyper::{header::HeaderValue, Body, Request, Response}; use log::info; use once_cell::sync::Lazy; -use redlib::client::{canonical_path, proxy, CLIENT}; +use redlib::client::{canonical_path, proxy, rate_limit_check, CLIENT}; use redlib::server::{self, RequestExt}; use redlib::utils::{error, redirect, ThemeAssets}; use redlib::{config, duplicates, headers, instance_info, post, search, settings, subreddit, user}; @@ -146,6 +146,16 @@ async fn main() { ) .get_matches(); + match rate_limit_check().await { + Ok(()) => { + info!("[✅] Rate limit check passed"); + }, + Err(e) => { + log::error!("[❌] Rate limit check failed: {}", e); + std::process::exit(1); + } + } + let address = matches.get_one::("address").unwrap(); let port = matches.get_one::("port").unwrap(); let hsts = matches.get_one("hsts").map(|m: &String| m.as_str()); diff --git a/src/oauth.rs b/src/oauth.rs index 576b647..12b0f37 100644 --- a/src/oauth.rs +++ b/src/oauth.rs @@ -38,12 +38,12 @@ impl Oauth { } Ok(None) => { error!("Failed to create OAuth client. Retrying in 5 seconds..."); - continue; } Err(duration) => { error!("Failed to create OAuth client in {duration:?}. Retrying in 5 seconds..."); } } + tokio::time::sleep(Duration::from_secs(5)).await; } } @@ -91,13 +91,14 @@ impl Oauth { // Build request let request = builder.body(body).unwrap(); - trace!("Sending token request..."); + trace!("Sending token request...\n\n{request:?}"); // Send request let client: &once_cell::sync::Lazy> = &CLIENT; let resp = client.request(request).await.ok()?; trace!("Received response with status {} and length {:?}", resp.status(), resp.headers().get("content-length")); + trace!("OAuth headers: {:#?}", resp.headers()); // Parse headers - loid header _should_ be saved sent on subsequent token refreshes. // Technically it's not needed, but it's easy for Reddit API to check for this. @@ -200,6 +201,7 @@ impl Device { ("x-reddit-media-codecs".into(), codecs), ("Content-Type".into(), "application/json; charset=UTF-8".into()), ("client-vendor-id".into(), uuid.clone()), + ("X-Reddit-Device-Id".into(), uuid.clone()), ]); info!("[🔄] Spoofing Android client with headers: {headers:?}, uuid: \"{uuid}\", and OAuth ID \"{REDDIT_ANDROID_OAUTH_CLIENT_ID}\"");