diff --git a/Cargo.toml b/Cargo.toml index 85d8c91..799fe68 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,3 +16,5 @@ log = "0.4.25" fern = "0.7.1" colored = "3.0.0" chrono = "0.4.39" +lazy_static = "1.5.0" +once_cell = "1.20.2" diff --git a/src/bin/main.rs b/src/bin/main.rs index 5ec8f15..74245e3 100644 --- a/src/bin/main.rs +++ b/src/bin/main.rs @@ -8,11 +8,15 @@ use async_trait::async_trait; use chrono::Local; use colored::{Color, Colorize}; use fern::Dispatch; +use lazy_static::lazy_static; use log::{Level, LevelFilter}; -use std::sync::Arc; +use once_cell::sync::{Lazy, OnceCell}; +use std::sync::{Arc, OnceLock}; +use tokio::task::JoinHandle; use LingTransmit::server::event::ServerEvent; use LingTransmit::server::Client::Client; use LingTransmit::server::Server; +use LingTransmit::shell::{register_command, start_shell, CommandActuators}; use LingTransmit::ssl::ServerCert; #[tokio::main] @@ -21,17 +25,30 @@ async fn main() { let pri = include_bytes!("../../ssl/test.pem"); let passwd = include_str!("../../ssl/pass.txt"); init_log(); - let server_cert = ServerCert::init_buffer_password( - &cert.to_vec(), - &pri.to_vec(), - passwd, + let server_cert = ServerCert::init_buffer_password(&cert.to_vec(), &pri.to_vec(), passwd) + .expect("解析证书失败"); + let server = Arc::new( + Server::new_tcp("0.0.0.0:11451", server_cert, Arc::new(Event {})) + .await + .expect("启动服务端失败"), + ); + let task = start_server(server.clone()); + register_command( + "exit".to_string(), + Box::new(ExitCommand::new(server.clone())), ) - .expect("解析证书失败"); - let server = Server::new_tcp("0.0.0.0:11451", server_cert, Arc::new(Event {})) - .await - .expect("启动服务端失败"); - server.start_accept().await; + .await + .expect("注册命令失败"); + start_shell().await; + task.await.expect("关闭服务器失败"); } + +fn start_server(server: Arc) -> JoinHandle<()> { + tokio::spawn(async move { + server.start_accept().await; + }) +} + fn get_time() -> String { let now = Local::now(); now.format("%Y-%m-%d %H:%M:%S").to_string() @@ -57,6 +74,24 @@ fn init_log() { .apply(); } +struct ExitCommand { + server: Arc, +} + +impl ExitCommand { + pub fn new(server: Arc) -> Self { + ExitCommand { server } + } +} + +#[async_trait] +impl CommandActuators for ExitCommand { + async fn execute(&self, command: String) { + self.server.close().await; + self.exit_shell(); + } +} + struct Event {} #[async_trait] diff --git a/src/lib.rs b/src/lib.rs index c80015b..b8edd5f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ pub mod server; pub mod close_sender; pub mod packet; -pub mod ssl; \ No newline at end of file +pub mod ssl; +pub mod shell; \ No newline at end of file diff --git a/src/server/accept.rs b/src/server/accept.rs index 24c1c67..696c7b9 100644 --- a/src/server/accept.rs +++ b/src/server/accept.rs @@ -30,7 +30,7 @@ impl OwnedWriteHalfAbstraction for tcp::OwnedWriteHalf {} impl OwnedWriteHalfAbstraction for unix::OwnedWriteHalf {} #[async_trait] -pub trait AcceptSocket { +pub trait AcceptSocket: Send + Sync { async fn accept( &self, ) -> io::Result<( diff --git a/src/server/mod.rs b/src/server/mod.rs index 47f6626..c04e6b9 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -29,7 +29,7 @@ pub type ClientList = Arc>>>; /// 服务器抽象 pub struct Server { - listener: Box, + listener: Arc, close_sender: CloseSender, client_list: ClientList, next_id: AtomicU64, @@ -38,7 +38,7 @@ pub struct Server { } impl Server { - fn new(listener: Box, cert: ServerCert, event: Arc) -> Self { + fn new(listener: Arc, cert: ServerCert, event: Arc) -> Self { Server { listener, close_sender: CloseSender::new(), @@ -55,7 +55,7 @@ impl Server { event: Arc, ) -> io::Result { let listener = TcpListener::bind(addr).await?; - Ok(Server::new(Box::new(listener), cert, event)) + Ok(Server::new(Arc::new(listener), cert, event)) } pub async fn new_unix

( @@ -67,7 +67,7 @@ impl Server { P: AsRef, { let unix = UnixListener::bind(path)?; - Ok(Server::new(Box::new(unix), cert, event)) + Ok(Server::new(Arc::new(unix), cert, event)) } /// 广播关闭消息 diff --git a/src/shell/mod.rs b/src/shell/mod.rs new file mode 100644 index 0000000..c536f40 --- /dev/null +++ b/src/shell/mod.rs @@ -0,0 +1,81 @@ +// 版权所有 (c) ling 保留所有权利。 +// 除非另行说明,否则仅允许在LingTransmit中使用此文件中的代码。 +// +// 由 ling 创建于 2025/1/19. +#![allow(non_snake_case)] + +use colored::Colorize; +use lazy_static::lazy_static; +use once_cell::sync::Lazy; +use std::collections::HashMap; +use std::io; +use std::io::Write; +use std::sync::atomic::{AtomicBool, Ordering}; +use tokio::sync::Mutex; + +/// 命令驱动器 +#[async_trait::async_trait] +pub trait CommandActuators: Send + Sync { + async fn execute(&self, command: String); + + /// 调用后,shell将停止读取输入 + fn exit_shell(&self) { + SHELL_FLAG.store(false, Ordering::Release); + } +} + +lazy_static! { + static ref COMMAND_MAP: Lazy>>> = + Lazy::new(|| { Mutex::new(HashMap::new()) }); +} +static SHELL_FLAG: AtomicBool = AtomicBool::new(true); + +/// 启动shell +pub async fn start_shell() { + loop { + print_shell(); + let mut input = String::new(); + io::stdin().read_line(&mut input).expect("读取输入出错"); + let input = input.trim(); + run_command(input.to_string()).await; + if !SHELL_FLAG.load(Ordering::Acquire) { + break; + } + } +} + +/// 注册命令 +pub async fn register_command( + name: String, + actuators: Box, +) -> Result<(), String> { + let mut map = COMMAND_MAP.lock().await; + if let Some(_) = map.get(&name) { + return Err(format!("命令 {} 已经存在", name)); + } + map.insert(name, actuators); + Ok(()) +} + +/// 取消注册命令 +pub async fn unregister_command(name: &str) { + let mut map = COMMAND_MAP.lock().await; + map.remove(name); +} + +async fn run_command(command: String) { + if let Some(first_word) = command.split_whitespace().next() { + let map = COMMAND_MAP.lock().await; + if let Some(actuators) = map.get(first_word) { + actuators.execute(command).await; + } else { + println!("{}", format!("找不到命令:{}", first_word).red()); + } + } +} + +/// 打印提示符 +fn print_shell() { + print!("[root@ling] # "); + io::stdout().flush().unwrap(); +}