feat: host check, actually connecting

This commit is contained in:
DarkCat09 2024-10-18 21:42:08 +04:00
parent e80d4caec9
commit 468e13fa6e
Signed by: DarkCat09
GPG key ID: BD3CE9B65916CD82
4 changed files with 80 additions and 20 deletions

View file

@ -3,6 +3,7 @@ use std::sync::Arc;
use crate::{
config::Config,
error::{AppError, HandlerError, HandlerResult},
host,
};
use socks5_server::{
@ -10,7 +11,7 @@ use socks5_server::{
proto::{Address, Reply},
Command, IncomingConnection,
};
use tokio::net::TcpStream;
use tokio::{io::AsyncWriteExt, net::TcpStream};
pub async fn handler(
conn: IncomingConnection<(), NeedAuthenticate>,
@ -38,19 +39,52 @@ pub async fn handler(
Command::Connect(cmd, addr) => {
let target = match addr {
Address::DomainAddress(host, port) => {
let Ok(host) = String::from_utf8(host) else {
let conn = cmd
.reply(Reply::GeneralFailure, Address::unspecified())
.await?
.into_inner();
return Err(
(std::io::Error::from(std::io::ErrorKind::InvalidData), conn).into(),
);
let mut addrs = match host::resolve_domain(host, port, config.clone()) {
Ok(iter) => iter,
Err(e) => {
let conn = cmd
.reply(Reply::GeneralFailure, Address::unspecified())
.await?
.into_inner();
return Err((e, conn).into());
}
};
TcpStream::connect((host.as_ref(), port)).await
TcpStream::connect(addrs.next().unwrap()).await // TODO
}
Address::SocketAddress(addr) => {
let addr = match host::resolve_ip(addr, config.clone()) {
Ok(addr) => addr,
Err(e) => {
let conn = cmd
.reply(Reply::GeneralFailure, Address::unspecified())
.await?
.into_inner();
return Err((e, conn).into());
}
};
TcpStream::connect(addr).await
}
Address::SocketAddress(addr) => TcpStream::connect(addr).await,
};
let mut target = match target {
Ok(strm) => strm,
Err(e) => {
let conn = cmd
.reply(Reply::HostUnreachable, Address::unspecified())
.await?
.into_inner();
return Err((e, conn).into());
}
};
let mut conn = cmd.reply(Reply::Succeeded, Address::unspecified()).await?;
let res = tokio::io::copy_bidirectional(&mut target, &mut conn).await;
let _ = target.shutdown().await;
res.map_err(|e| (e, conn.into_inner()))?;
}
}

View file

@ -1,9 +1,35 @@
use std::{net::SocketAddr, sync::Arc};
use std::{
net::{SocketAddr, ToSocketAddrs},
sync::Arc,
};
use crate::{config::Config, error::AppResult};
use crate::{
config::Config,
error::{AppError, AppResult},
serde_addr::{AddrIter, TargetAddr, LOCALHOST_V4},
};
pub fn parse_str(value: Vec<u8>, config: Arc<Config>) -> AppResult<SocketAddr> {
let host = String::from_utf8(value)?;
pub fn resolve_domain(host: Vec<u8>, port: u16, config: Arc<Config>) -> AppResult<AddrIter> {
let host = String::from_utf8(host)?;
let target = TargetAddr::DomainName(host, port);
//
if config.local.contains(&target) {
Ok(SocketAddr::new(LOCALHOST_V4, port).into())
} else if let Some(remote) = config.remote.get(&target) {
Ok(remote.to_socket_addrs()?)
} else {
Err(AppError::ThirdPartyHost)
}
}
pub fn resolve_ip(addr: SocketAddr, config: Arc<Config>) -> AppResult<SocketAddr> {
let target = TargetAddr::IpAddress(addr);
if config.local.contains(&target) {
Ok(SocketAddr::new(LOCALHOST_V4, addr.port()))
} else if let Some(remote) = config.remote.get(&target) {
Ok(remote.to_socket_addrs()?.next().unwrap())
} else {
Err(AppError::ThirdPartyHost)
}
}

View file

@ -1,7 +1,7 @@
mod config;
mod error;
mod handler;
// mod host;
mod host;
mod serde_addr;
use std::sync::Arc;

View file

@ -2,6 +2,7 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs};
use serde::{de::Visitor, Deserialize};
pub const LOCALHOST_V4: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
pub const DEFAULT_PORT: u16 = 5232;
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
@ -137,13 +138,12 @@ impl From<SocketAddr> for AddrIter {
#[cfg(test)]
mod tests {
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::net::{IpAddr, Ipv6Addr};
use serde_test::{assert_de_tokens, Token};
use super::{TargetAddr, DEFAULT_PORT};
use super::{TargetAddr, DEFAULT_PORT, LOCALHOST_V4};
const LOCALHOST_V4: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
const TEST_V6: IpAddr = IpAddr::V6(Ipv6Addr::new(
0xfe80, 0x0, 0x0, 0x0, 0x721e, 0xd21f, 0x29a3, 0xf396,
));