From 468e13fa6e2e6d51f024d872517f055bcb79ac06 Mon Sep 17 00:00:00 2001 From: DarkCat09 Date: Fri, 18 Oct 2024 21:42:08 +0400 Subject: [PATCH] feat: host check, actually connecting --- src/handler.rs | 56 +++++++++++++++++++++++++++++++++++++---------- src/host.rs | 36 +++++++++++++++++++++++++----- src/main.rs | 2 +- src/serde_addr.rs | 6 ++--- 4 files changed, 80 insertions(+), 20 deletions(-) diff --git a/src/handler.rs b/src/handler.rs index b6dd21f..88fda29 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use crate::{ config::Config, error::{AppError, HandlerError, HandlerResult}, + host, }; use socks5_server::{ @@ -10,7 +11,7 @@ use socks5_server::{ proto::{Address, Reply}, Command, IncomingConnection, }; -use tokio::net::TcpStream; +use tokio::{io::AsyncWriteExt, net::TcpStream}; pub async fn handler( conn: IncomingConnection<(), NeedAuthenticate>, @@ -38,19 +39,52 @@ pub async fn handler( Command::Connect(cmd, addr) => { let target = match addr { Address::DomainAddress(host, port) => { - let Ok(host) = String::from_utf8(host) else { - let conn = cmd - .reply(Reply::GeneralFailure, Address::unspecified()) - .await? - .into_inner(); - return Err( - (std::io::Error::from(std::io::ErrorKind::InvalidData), conn).into(), - ); + let mut addrs = match host::resolve_domain(host, port, config.clone()) { + Ok(iter) => iter, + Err(e) => { + let conn = cmd + .reply(Reply::GeneralFailure, Address::unspecified()) + .await? + .into_inner(); + return Err((e, conn).into()); + } }; - TcpStream::connect((host.as_ref(), port)).await + + TcpStream::connect(addrs.next().unwrap()).await // TODO + } + Address::SocketAddress(addr) => { + let addr = match host::resolve_ip(addr, config.clone()) { + Ok(addr) => addr, + Err(e) => { + let conn = cmd + .reply(Reply::GeneralFailure, Address::unspecified()) + .await? + .into_inner(); + return Err((e, conn).into()); + } + }; + + TcpStream::connect(addr).await } - Address::SocketAddress(addr) => TcpStream::connect(addr).await, }; + + let mut target = match target { + Ok(strm) => strm, + Err(e) => { + let conn = cmd + .reply(Reply::HostUnreachable, Address::unspecified()) + .await? + .into_inner(); + return Err((e, conn).into()); + } + }; + + let mut conn = cmd.reply(Reply::Succeeded, Address::unspecified()).await?; + + let res = tokio::io::copy_bidirectional(&mut target, &mut conn).await; + let _ = target.shutdown().await; + + res.map_err(|e| (e, conn.into_inner()))?; } } diff --git a/src/host.rs b/src/host.rs index 28c5d4d..57ecafc 100644 --- a/src/host.rs +++ b/src/host.rs @@ -1,9 +1,35 @@ -use std::{net::SocketAddr, sync::Arc}; +use std::{ + net::{SocketAddr, ToSocketAddrs}, + sync::Arc, +}; -use crate::{config::Config, error::AppResult}; +use crate::{ + config::Config, + error::{AppError, AppResult}, + serde_addr::{AddrIter, TargetAddr, LOCALHOST_V4}, +}; -pub fn parse_str(value: Vec, config: Arc) -> AppResult { - let host = String::from_utf8(value)?; +pub fn resolve_domain(host: Vec, port: u16, config: Arc) -> AppResult { + let host = String::from_utf8(host)?; + let target = TargetAddr::DomainName(host, port); - // + if config.local.contains(&target) { + Ok(SocketAddr::new(LOCALHOST_V4, port).into()) + } else if let Some(remote) = config.remote.get(&target) { + Ok(remote.to_socket_addrs()?) + } else { + Err(AppError::ThirdPartyHost) + } +} + +pub fn resolve_ip(addr: SocketAddr, config: Arc) -> AppResult { + let target = TargetAddr::IpAddress(addr); + + if config.local.contains(&target) { + Ok(SocketAddr::new(LOCALHOST_V4, addr.port())) + } else if let Some(remote) = config.remote.get(&target) { + Ok(remote.to_socket_addrs()?.next().unwrap()) + } else { + Err(AppError::ThirdPartyHost) + } } diff --git a/src/main.rs b/src/main.rs index c0922d0..6836dd6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,7 @@ mod config; mod error; mod handler; -// mod host; +mod host; mod serde_addr; use std::sync::Arc; diff --git a/src/serde_addr.rs b/src/serde_addr.rs index 496e9b2..6ce1848 100644 --- a/src/serde_addr.rs +++ b/src/serde_addr.rs @@ -2,6 +2,7 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs}; use serde::{de::Visitor, Deserialize}; +pub const LOCALHOST_V4: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); pub const DEFAULT_PORT: u16 = 5232; #[derive(Debug, Clone, Eq, PartialEq, Hash)] @@ -137,13 +138,12 @@ impl From for AddrIter { #[cfg(test)] mod tests { - use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; + use std::net::{IpAddr, Ipv6Addr}; use serde_test::{assert_de_tokens, Token}; - use super::{TargetAddr, DEFAULT_PORT}; + use super::{TargetAddr, DEFAULT_PORT, LOCALHOST_V4}; - const LOCALHOST_V4: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); const TEST_V6: IpAddr = IpAddr::V6(Ipv6Addr::new( 0xfe80, 0x0, 0x0, 0x0, 0x721e, 0xd21f, 0x29a3, 0xf396, ));