feat: host check, actually connecting
This commit is contained in:
parent
e80d4caec9
commit
468e13fa6e
4 changed files with 80 additions and 20 deletions
|
@ -3,6 +3,7 @@ use std::sync::Arc;
|
|||
use crate::{
|
||||
config::Config,
|
||||
error::{AppError, HandlerError, HandlerResult},
|
||||
host,
|
||||
};
|
||||
|
||||
use socks5_server::{
|
||||
|
@ -10,7 +11,7 @@ use socks5_server::{
|
|||
proto::{Address, Reply},
|
||||
Command, IncomingConnection,
|
||||
};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::{io::AsyncWriteExt, net::TcpStream};
|
||||
|
||||
pub async fn handler(
|
||||
conn: IncomingConnection<(), NeedAuthenticate>,
|
||||
|
@ -38,19 +39,52 @@ pub async fn handler(
|
|||
Command::Connect(cmd, addr) => {
|
||||
let target = match addr {
|
||||
Address::DomainAddress(host, port) => {
|
||||
let Ok(host) = String::from_utf8(host) else {
|
||||
let mut addrs = match host::resolve_domain(host, port, config.clone()) {
|
||||
Ok(iter) => iter,
|
||||
Err(e) => {
|
||||
let conn = cmd
|
||||
.reply(Reply::GeneralFailure, Address::unspecified())
|
||||
.await?
|
||||
.into_inner();
|
||||
return Err(
|
||||
(std::io::Error::from(std::io::ErrorKind::InvalidData), conn).into(),
|
||||
);
|
||||
};
|
||||
TcpStream::connect((host.as_ref(), port)).await
|
||||
return Err((e, conn).into());
|
||||
}
|
||||
Address::SocketAddress(addr) => TcpStream::connect(addr).await,
|
||||
};
|
||||
|
||||
TcpStream::connect(addrs.next().unwrap()).await // TODO
|
||||
}
|
||||
Address::SocketAddress(addr) => {
|
||||
let addr = match host::resolve_ip(addr, config.clone()) {
|
||||
Ok(addr) => addr,
|
||||
Err(e) => {
|
||||
let conn = cmd
|
||||
.reply(Reply::GeneralFailure, Address::unspecified())
|
||||
.await?
|
||||
.into_inner();
|
||||
return Err((e, conn).into());
|
||||
}
|
||||
};
|
||||
|
||||
TcpStream::connect(addr).await
|
||||
}
|
||||
};
|
||||
|
||||
let mut target = match target {
|
||||
Ok(strm) => strm,
|
||||
Err(e) => {
|
||||
let conn = cmd
|
||||
.reply(Reply::HostUnreachable, Address::unspecified())
|
||||
.await?
|
||||
.into_inner();
|
||||
return Err((e, conn).into());
|
||||
}
|
||||
};
|
||||
|
||||
let mut conn = cmd.reply(Reply::Succeeded, Address::unspecified()).await?;
|
||||
|
||||
let res = tokio::io::copy_bidirectional(&mut target, &mut conn).await;
|
||||
let _ = target.shutdown().await;
|
||||
|
||||
res.map_err(|e| (e, conn.into_inner()))?;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
36
src/host.rs
36
src/host.rs
|
@ -1,9 +1,35 @@
|
|||
use std::{net::SocketAddr, sync::Arc};
|
||||
use std::{
|
||||
net::{SocketAddr, ToSocketAddrs},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use crate::{config::Config, error::AppResult};
|
||||
use crate::{
|
||||
config::Config,
|
||||
error::{AppError, AppResult},
|
||||
serde_addr::{AddrIter, TargetAddr, LOCALHOST_V4},
|
||||
};
|
||||
|
||||
pub fn parse_str(value: Vec<u8>, config: Arc<Config>) -> AppResult<SocketAddr> {
|
||||
let host = String::from_utf8(value)?;
|
||||
pub fn resolve_domain(host: Vec<u8>, port: u16, config: Arc<Config>) -> AppResult<AddrIter> {
|
||||
let host = String::from_utf8(host)?;
|
||||
let target = TargetAddr::DomainName(host, port);
|
||||
|
||||
//
|
||||
if config.local.contains(&target) {
|
||||
Ok(SocketAddr::new(LOCALHOST_V4, port).into())
|
||||
} else if let Some(remote) = config.remote.get(&target) {
|
||||
Ok(remote.to_socket_addrs()?)
|
||||
} else {
|
||||
Err(AppError::ThirdPartyHost)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn resolve_ip(addr: SocketAddr, config: Arc<Config>) -> AppResult<SocketAddr> {
|
||||
let target = TargetAddr::IpAddress(addr);
|
||||
|
||||
if config.local.contains(&target) {
|
||||
Ok(SocketAddr::new(LOCALHOST_V4, addr.port()))
|
||||
} else if let Some(remote) = config.remote.get(&target) {
|
||||
Ok(remote.to_socket_addrs()?.next().unwrap())
|
||||
} else {
|
||||
Err(AppError::ThirdPartyHost)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
mod config;
|
||||
mod error;
|
||||
mod handler;
|
||||
// mod host;
|
||||
mod host;
|
||||
mod serde_addr;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
|
|
@ -2,6 +2,7 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs};
|
|||
|
||||
use serde::{de::Visitor, Deserialize};
|
||||
|
||||
pub const LOCALHOST_V4: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
|
||||
pub const DEFAULT_PORT: u16 = 5232;
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
|
||||
|
@ -137,13 +138,12 @@ impl From<SocketAddr> for AddrIter {
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||
use std::net::{IpAddr, Ipv6Addr};
|
||||
|
||||
use serde_test::{assert_de_tokens, Token};
|
||||
|
||||
use super::{TargetAddr, DEFAULT_PORT};
|
||||
use super::{TargetAddr, DEFAULT_PORT, LOCALHOST_V4};
|
||||
|
||||
const LOCALHOST_V4: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
|
||||
const TEST_V6: IpAddr = IpAddr::V6(Ipv6Addr::new(
|
||||
0xfe80, 0x0, 0x0, 0x0, 0x721e, 0xd21f, 0x29a3, 0xf396,
|
||||
));
|
||||
|
|
Loading…
Add table
Reference in a new issue