diff --git a/src/client.rs b/src/client.rs index 1e1661d..aebbe01 100644 --- a/src/client.rs +++ b/src/client.rs @@ -2,8 +2,10 @@ use arc_swap::ArcSwap; use cached::proc_macro::cached; use futures_lite::future::block_on; use futures_lite::{future::Boxed, FutureExt}; +use hyper::body::HttpBody; use hyper::client::HttpConnector; use hyper::header::HeaderValue; +use hyper::StatusCode; use hyper::{body, body::Buf, header, Body, Client, Method, Request, Response, Uri}; use hyper_rustls::HttpsConnector; use libflate::gzip; @@ -12,8 +14,8 @@ use once_cell::sync::Lazy; use percent_encoding::{percent_encode, CONTROLS}; use serde_json::Value; -use std::sync::atomic::Ordering; use std::sync::atomic::{AtomicBool, AtomicU16}; +use std::sync::atomic::{AtomicU64, Ordering}; use std::{io, result::Result}; use crate::dbg_msg; @@ -45,6 +47,8 @@ pub static OAUTH_RATELIMIT_REMAINING: AtomicU16 = AtomicU16::new(99); pub static OAUTH_IS_ROLLING_OVER: AtomicBool = AtomicBool::new(false); +pub static REQUEST_COUNT: AtomicU64 = AtomicU64::new(0); + static URL_PAIRS: [(&str, &str); 2] = [ (ALTERNATIVE_REDDIT_URL_BASE, ALTERNATIVE_REDDIT_URL_BASE_HOST), (REDDIT_SHORT_URL_BASE, REDDIT_SHORT_URL_BASE_HOST), @@ -240,8 +244,6 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo } } - trace!("Headers: {:#?}", headers); - // shuffle headers: https://github.com/redlib-org/redlib/issues/324 fastrand::shuffle(&mut headers); @@ -375,13 +377,24 @@ pub async fn json(path: String, quarantine: bool) -> Result { tokio::spawn(force_refresh_token()); } OAUTH_RATELIMIT_REMAINING.fetch_sub(1, Ordering::SeqCst); + REQUEST_COUNT.fetch_add(1, Ordering::Relaxed); // Fetch the url... match reddit_get(path.clone(), quarantine).await { Ok(response) => { let status = response.status(); - let reset: Option = if let (Some(remaining), Some(reset), Some(used)) = ( + if status == StatusCode::FORBIDDEN { + error!("Generic network policy error. Total requests: {}", REQUEST_COUNT.load(Ordering::Relaxed)); + let mut arr = vec![]; + let mut body = response.collect().await.unwrap_or_default().aggregate(); + body.copy_to_slice(&mut arr); + let body_str = String::from_utf8_lossy(&arr); + trace!("Network policy error body: \n{body_str}"); + return Err(format!("Generic network policy error. Total requests: {}", REQUEST_COUNT.load(Ordering::Relaxed))); + } + + if let (Some(remaining), Some(reset), Some(used)) = ( response.headers().get("x-ratelimit-remaining").and_then(|val| val.to_str().ok().map(|s| s.to_string())), response.headers().get("x-ratelimit-reset").and_then(|val| val.to_str().ok().map(|s| s.to_string())), response.headers().get("x-ratelimit-used").and_then(|val| val.to_str().ok().map(|s| s.to_string())), @@ -390,27 +403,16 @@ pub async fn json(path: String, quarantine: bool) -> Result { "Ratelimit remaining: Header says {remaining}, we have {current_rate_limit}. Resets in {reset}. Rollover: {}. Ratelimit used: {used}", if is_rolling_over { "yes" } else { "no" }, ); - Some(reset) - } else { - None - }; + + if let Ok(remaining) = remaining.parse::() { + OAUTH_RATELIMIT_REMAINING.store(remaining, Ordering::SeqCst); + } + } // asynchronously aggregate the chunks of the body - match hyper::body::aggregate(response).await { + match response.collect().await { Ok(body) => { - let has_remaining = body.has_remaining(); - - if !has_remaining { - // Rate limited, so spawn a force_refresh_token() - tokio::spawn(force_refresh_token()); - return match reset { - Some(val) => Err(format!( - "Reddit rate limit exceeded. Try refreshing in a few seconds.\ - Rate limit will reset in: {val}" - )), - None => Err("Reddit rate limit exceeded".to_string()), - }; - } + let body = body.aggregate(); // Parse the response from Reddit as JSON match serde_json::from_reader(body.reader()) { diff --git a/src/main.rs b/src/main.rs index 799b491..b66893f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -385,7 +385,7 @@ async fn main() { error!("Please update to the latest version. Then, check the issue tracker for the latest error."); error!("https://github.com/redlib-org/redlib/issues"); panic!("Self-test failed"); - }, + } } } diff --git a/src/oauth.rs b/src/oauth.rs index 576b647..9a4fa5e 100644 --- a/src/oauth.rs +++ b/src/oauth.rs @@ -38,12 +38,12 @@ impl Oauth { } Ok(None) => { error!("Failed to create OAuth client. Retrying in 5 seconds..."); - continue; } Err(duration) => { error!("Failed to create OAuth client in {duration:?}. Retrying in 5 seconds..."); } } + tokio::time::sleep(Duration::from_secs(5)).await; } } @@ -91,7 +91,7 @@ impl Oauth { // Build request let request = builder.body(body).unwrap(); - trace!("Sending token request..."); + trace!("Sending token request...\n\n{request:#?}"); // Send request let client: &once_cell::sync::Lazy> = &CLIENT;