diff --git a/.gitignore b/.gitignore index 408b8a5..ecc9bed 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ /target Cargo.lock -.idea \ No newline at end of file +.idea +ssl \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index e257768..85d8c91 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,9 +4,8 @@ version = "0.1.0" edition = "2021" [lib] -name = "LingTransmit" -crate-type = ["staticlib"] - +crate-type = ["rlib", "staticlib"] +path = "src/lib.rs" [dependencies] openssl = "0.10.68" @@ -14,3 +13,6 @@ tokio = { version = "1.43.0", features = ["full"] } async-trait = "0.1.85" tokio-macros = "2.5.0" log = "0.4.25" +fern = "0.7.1" +colored = "3.0.0" +chrono = "0.4.39" diff --git a/src/bin/main.rs b/src/bin/main.rs index 2e21c16..5ec8f15 100644 --- a/src/bin/main.rs +++ b/src/bin/main.rs @@ -4,4 +4,77 @@ // 由 ling 创建于 2025/1/18. #![allow(non_snake_case)] -fn main() {} +use async_trait::async_trait; +use chrono::Local; +use colored::{Color, Colorize}; +use fern::Dispatch; +use log::{Level, LevelFilter}; +use std::sync::Arc; +use LingTransmit::server::event::ServerEvent; +use LingTransmit::server::Client::Client; +use LingTransmit::server::Server; +use LingTransmit::ssl::ServerCert; + +#[tokio::main] +async fn main() { + let cert = include_bytes!("../../ssl/test_cert.pem"); + let pri = include_bytes!("../../ssl/test.pem"); + let passwd = include_str!("../../ssl/pass.txt"); + init_log(); + let server_cert = ServerCert::init_buffer_password( + &cert.to_vec(), + &pri.to_vec(), + passwd, + ) + .expect("解析证书失败"); + let server = Server::new_tcp("0.0.0.0:11451", server_cert, Arc::new(Event {})) + .await + .expect("启动服务端失败"); + server.start_accept().await; +} +fn get_time() -> String { + let now = Local::now(); + now.format("%Y-%m-%d %H:%M:%S").to_string() +} +fn init_log() { + let console_dispatch = Dispatch::new() + .format(|out, message, record| { + let (title, color) = match record.level() { + Level::Error => ("Error", Color::Red), + Level::Warn => ("Warning", Color::Yellow), + Level::Info => ("Info", Color::Green), + Level::Debug => ("Debug", Color::BrightWhite), + Level::Trace => ("Trace", Color::White), + }; + + out.finish(format_args!( + "{}", + format!("[{} {}]\t{}", get_time(), title, message).color(color) + )) + }) + .chain(std::io::stdout()) + .level(LevelFilter::Trace) + .apply(); +} + +struct Event {} + +#[async_trait] +impl ServerEvent for Event { + async fn client_linker_listener(&self, client: Arc) { + println!("客户端连入,ID:{}", client.id); + } + + async fn client_close_listener(&self, client: Arc) { + println!("客户端挂断,ID:{}", client.id) + } + + async fn client_user_data(&self, client: Arc, packet: Vec) -> std::io::Result<()> { + println!( + "客户端发送数据,ID:{},数据长度:{}", + client.id, + packet.len() + ); + Ok(()) + } +} diff --git a/src/packet/mod.rs b/src/packet/mod.rs index f17df74..688d76c 100644 --- a/src/packet/mod.rs +++ b/src/packet/mod.rs @@ -11,6 +11,7 @@ use crate::server::accept::OwnedReadHalfAbstraction; use crate::server::Client::Client; use std::fmt::format; use std::sync::atomic::Ordering; +use std::sync::Arc; use std::{error, io}; use tokio::io::AsyncReadExt; @@ -27,7 +28,7 @@ pub enum NetworkPackets { /// 从字节流中读取一个数据包 pub async fn read_packet( read: &mut Box, - client: &Client, + client: Arc, ) -> io::Result { //根据协议,通信伊始收到 SYN V1 信息,则使用V1协议加密通信 //在此实现中,彻底摈弃了未加密的不安全数据。 @@ -35,7 +36,7 @@ pub async fn read_packet( if client.syn_version.load(Ordering::Acquire) == 0 && !client.is_key_negotiation.load(Ordering::Acquire) { - let syn = read.read_i32().await?; + let syn = read.read_i32_le().await?; if syn != LING_SYN_V1 { return Err(make_error("客户端尝试使用未加密连接交换数据")); } @@ -44,7 +45,7 @@ pub async fn read_packet( return Ok(NetworkPackets::SynV1); } //读取数据包长度 - let mut size = read.read_i32().await?; + let mut size = read.read_i32_le().await?; // 在旧版本的协议中,没有魔数标记数据包开始,数据包通过 int32_t 类型的size作为开始。 // 如果传输过程中出错,导致读取位置发生些许偏差,则size不准,可能错误分配堆内存,带来安全隐患。 // 自 Ling V1.1开始,数据包头部必须使用 LING_START 开始 @@ -56,7 +57,7 @@ pub async fn read_packet( )); } //由于size实际上读取的是 LING_START 魔数,所以要重新读取四个字节作为数据包实际大小 - size = read.read_i32().await?; + size = read.read_i32_le().await?; //附带魔数后,将数据包长度限制提高到 512kb if size > MAX_SIZE_V1_1 { return Err(make_error( @@ -73,7 +74,7 @@ pub async fn read_packet( } // 数据包类型 - let data_type = match read.read_i32().await? { + let data_type = match read.read_i32_le().await? { API_TYPE_ASK => API_TYPE_ASK, API_TYPE_PUSH_AES_KEY => API_TYPE_PUSH_AES_KEY, num => { @@ -91,6 +92,15 @@ pub async fn read_packet( ))); } + // 根据协议,每一个数据包都必须以 LING_STOP 作为结束 + let stop = read.read_i32_le().await?; + if stop != LING_STOP { + return Err(make_error(format!( + "数据包结束标记错误,需要 {:X},但发现 {:X}", + LING_STOP, stop + ))); + } + match data_type { API_TYPE_ASK => Ok(NetworkPackets::UserAsk(buffer)), API_TYPE_PUSH_AES_KEY => Ok(NetworkPackets::PushAesKey(buffer)), diff --git a/src/server/Client.rs b/src/server/Client.rs index b4e6ce8..f3575c6 100644 --- a/src/server/Client.rs +++ b/src/server/Client.rs @@ -8,6 +8,7 @@ use crate::close_sender::CloseSender; use crate::packet::code::{SERVER_ACK, SERVER_ERROR}; use crate::packet::{read_packet, NetworkPackets}; use crate::server::accept::{OwnedReadHalfAbstraction, OwnedWriteHalfAbstraction, SocketAddr}; +use crate::server::event::ServerEvent; use crate::server::ClientID; use crate::ssl::ServerCert; use log::{error, info}; @@ -41,6 +42,7 @@ pub struct Client { pub key: OnceLock, /// 服务器证书 cert: Arc, + pub(super) event: Arc, } impl Client { @@ -52,6 +54,7 @@ impl Client { id: ClientID, addr: SocketAddr, cert: Arc, + event: Arc, ) -> Self { Client { server_close, @@ -64,6 +67,7 @@ impl Client { is_key_negotiation: AtomicBool::new(false), key: OnceLock::new(), cert, + event, } } @@ -79,7 +83,7 @@ impl Client { } /// 开始处理该客户端的请求 - pub async fn start(&self) { + pub async fn start(self: Arc) { loop { //从字节流中读取一个数据包 let packet = tokio::select! { @@ -90,14 +94,17 @@ impl Client { Err(io::Error::new(io::ErrorKind::NotFound,"读取端已经被挂断")) } Some(ref mut val) => { - read_packet(val,self).await + read_packet(val,self.clone()).await } } } => { match packet { Ok(val) => {val} Err(err) => { - info!("{} 号连接读取数据包出错:{}",self.id,err.to_string()); + // 如果仅仅是读取到文件尾,说明客户端只是挂断了,不是什么大问题 + if err.kind() != io::ErrorKind::UnexpectedEof { + error!("{} 号连接读取数据包出错:{}",self.id,err.to_string()); + } return; } } @@ -130,10 +137,10 @@ impl Client { } /// 处理客户端请求 - async fn process_packet(&self, packet: NetworkPackets) -> io::Result<()> { + async fn process_packet(self: &Arc, packet: NetworkPackets) -> io::Result<()> { match packet { NetworkPackets::SynV1 => self.syn_v1().await, - NetworkPackets::UserAsk(_) => Ok(()), + NetworkPackets::UserAsk(data) => self.event.client_user_data(self.clone(), data).await, NetworkPackets::PushAesKey(data) => self.client_push_key(&data).await, } } @@ -157,10 +164,10 @@ impl Client { let mut send = self.write_soc.lock().await; if let Err(_) = self.key.set(key.clone()) { //重复推送,拒绝密钥 - send.write_i32(SERVER_ERROR).await?; + send.write_i32_le(SERVER_ERROR).await?; return Ok(()); } - send.write_i32(SERVER_ACK).await?; + send.write_i32_le(SERVER_ACK).await?; Ok(()) } @@ -168,7 +175,7 @@ impl Client { async fn syn_v1(&self) -> io::Result<()> { let certificate = self.cert.certificate.to_pem()?; let mut send = self.write_soc.lock().await; - send.write_i64(certificate.len() as i64).await?; + send.write_i64_le(certificate.len() as i64).await?; send.write(&certificate).await?; Ok(()) } diff --git a/src/server/accept.rs b/src/server/accept.rs index eb1dce8..24c1c67 100644 --- a/src/server/accept.rs +++ b/src/server/accept.rs @@ -9,53 +9,13 @@ use tokio::io; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::{tcp, unix, TcpListener, UnixListener}; -/// 读取抽象,使用小端序 +/// 读取抽象 #[async_trait] -pub trait OwnedReadHalfAbstraction: AsyncRead + Unpin + Send + Sync { - async fn read_i32(&mut self) -> io::Result { - let mut buffer = [0u8; 4]; - self.read_exact(&mut buffer).await?; - Ok(i32::from_le_bytes(buffer)) - } +pub trait OwnedReadHalfAbstraction: AsyncRead + Unpin + Send + Sync {} - async fn read_i64(&mut self) -> io::Result { - let mut buffer = [0u8; 8]; - self.read_exact(&mut buffer).await?; - Ok(i64::from_le_bytes(buffer)) - } - - async fn read_u32(&mut self) -> io::Result { - let mut buffer = [0u8; 4]; - self.read_exact(&mut buffer).await?; - Ok(u32::from_le_bytes(buffer)) - } - - async fn read_u64(&mut self) -> io::Result { - let mut buffer = [0u8; 8]; - self.read_exact(&mut buffer).await?; - Ok(u64::from_le_bytes(buffer)) - } -} - -/// 写入抽象,使用小端序 +/// 写入抽象 #[async_trait] -pub trait OwnedWriteHalfAbstraction: AsyncWrite + Unpin + Send + Sync { - async fn write_i32(&mut self, value: i32) -> io::Result { - self.write(&value.to_le_bytes()).await - } - - async fn write_i64(&mut self, value: i64) -> io::Result { - self.write(&value.to_le_bytes()).await - } - - async fn write_u32(&mut self, value: u32) -> io::Result { - self.write(&value.to_le_bytes()).await - } - - async fn write_u64(&mut self, value: u64) -> io::Result { - self.write(&value.to_le_bytes()).await - } -} +pub trait OwnedWriteHalfAbstraction: AsyncWrite + Unpin + Send + Sync {} #[async_trait] impl OwnedReadHalfAbstraction for tcp::OwnedReadHalf {} diff --git a/src/server/event.rs b/src/server/event.rs new file mode 100644 index 0000000..1d3c14c --- /dev/null +++ b/src/server/event.rs @@ -0,0 +1,28 @@ +// 版权所有 (c) ling 保留所有权利。 +// 除非另行说明,否则仅允许在LingTransmit中使用此文件中的代码。 +// +// 由 ling 创建于 2025/1/19. +#![allow(non_snake_case)] + +use crate::server::Client; +use async_trait::async_trait; +use std::io; +use std::sync::Arc; +use tokio::sync::broadcast; +use tokio::sync::broadcast::error::{RecvError, SendError}; + +#[async_trait] +pub trait ServerEvent: Send + Sync { + /// 客户端连入事件 + async fn client_linker_listener(&self, client: Arc); + + /// 客户端挂断事件 + async fn client_close_listener(&self, client: Arc); + + /// 客户端请求 + async fn client_user_data( + &self, + client: Arc, + packet: Vec, + ) -> io::Result<()>; +} diff --git a/src/server/mod.rs b/src/server/mod.rs index 0155840..47f6626 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -6,9 +6,11 @@ pub mod Client; pub mod accept; +pub mod event; use crate::close_sender::CloseSender; use crate::server::accept::AcceptSocket; +use crate::server::event::ServerEvent; use crate::ssl::ServerCert; use async_trait::async_trait; use log::{debug, error}; @@ -32,30 +34,40 @@ pub struct Server { client_list: ClientList, next_id: AtomicU64, cert: Arc, + event: Arc, } impl Server { - fn new(listener: Box, cert: ServerCert) -> Self { + fn new(listener: Box, cert: ServerCert, event: Arc) -> Self { Server { listener, close_sender: CloseSender::new(), client_list: Arc::new(Mutex::new(HashMap::new())), next_id: AtomicU64::new(0), cert: Arc::new(cert), + event, } } - pub async fn new_tcp(addr: A, cert: ServerCert) -> io::Result { + pub async fn new_tcp( + addr: A, + cert: ServerCert, + event: Arc, + ) -> io::Result { let listener = TcpListener::bind(addr).await?; - Ok(Server::new(Box::new(listener), cert)) + Ok(Server::new(Box::new(listener), cert, event)) } - pub async fn new_unix

(path: P, cert: ServerCert) -> io::Result + pub async fn new_unix

( + path: P, + cert: ServerCert, + event: Arc, + ) -> io::Result where P: AsRef, { let unix = UnixListener::bind(path)?; - Ok(Server::new(Box::new(unix), cert)) + Ok(Server::new(Box::new(unix), cert, event)) } /// 广播关闭消息 @@ -69,15 +81,18 @@ impl Server { } /// 挂断一个客户端 - pub async fn close_client(&self, id: ClientID) { - Self::close_client_form_arc(&self.client_list, id).await; + pub async fn close_client(&self, client: &Arc) { + Self::close_client_form_arc(&self.client_list, client).await; } - pub async fn close_client_form_arc(list: &ClientList, id: ClientID) { + pub async fn close_client_form_arc(list: &ClientList, client: &Arc) { + let client_id = client.id; let mut lock = list.lock().await; - if let Some(client) = lock.get(&id) { + if let Some(client) = lock.get(&client_id) { + //向使用者报告客户端关闭 + client.event.client_close_listener(client.clone()).await; client.close().await; - lock.remove(&id); + lock.remove(&client_id); } } @@ -109,17 +124,21 @@ impl Server { id, addr, self.cert.clone(), + self.event.clone(), )); let mut lock = self.client_list.lock().await; lock.insert(id, client.clone()); drop(lock); + //向使用者报告新的客户端连入 + self.event.client_linker_listener(client.clone()).await; + let list = self.get_client_list(); tokio::spawn(async move { - client.start().await; + client.clone().start().await; //当连接的事件轮退出,则自动挂断 - Self::close_client_form_arc(&list, id).await; + Self::close_client_form_arc(&list, &client).await; }); Ok(()) }