From 67f9114767b4a1621a93fe5f0224ad5f0310dbc8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A9=BB=E9=AD=82=E5=9C=A3=E4=BD=BF?= Date: Sun, 23 Feb 2025 22:15:44 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=9E=E7=8E=B0=E4=BC=A0=E8=BE=93=E8=BF=87?= =?UTF-8?q?=E7=A8=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/bin/client.rs | 16 ++++++ src/bin/logger/mod.rs | 36 ++++++++++++++ src/bin/main.rs | 37 ++++---------- src/client/mod.rs | 113 ++++++++++++++++++++++++++++++++++++++++-- src/packet/code.rs | 12 +++-- src/packet/mod.rs | 2 +- src/server/Client.rs | 107 +++++++++++++++++++++++++++++++++++---- src/stream.rs | 2 +- 8 files changed, 279 insertions(+), 46 deletions(-) create mode 100644 src/bin/logger/mod.rs diff --git a/src/bin/client.rs b/src/bin/client.rs index 13ab3c4..196186b 100644 --- a/src/bin/client.rs +++ b/src/bin/client.rs @@ -4,12 +4,28 @@ // 由 ling 创建于 2025/2/23. #![allow(non_snake_case)] +mod logger; + +use log::{error, info}; use openssl::x509::X509; use LingTransmit::client::Client; +use LingTransmit::packet::NetworkPackets; +use crate::logger::init_log; #[tokio::main] async fn main() { + init_log(); let cert = include_bytes!("../../assert/ZHSSCA.crt"); let cert = X509::from_pem(cert).unwrap(); let client = Client::tcp_connect("localhost", 11451, cert).await.unwrap(); + + client.send_buffer("你好,世界!".as_bytes()).await.unwrap(); + + let packet = client.read_packet().await.unwrap(); + if let NetworkPackets::UserAsk(data) = packet { + let str = String::from_utf8_lossy(&*data).to_string(); + info!("{}", str); + } else { + client.execute_protocol_packet(packet).await.unwrap(); + } } diff --git a/src/bin/logger/mod.rs b/src/bin/logger/mod.rs new file mode 100644 index 0000000..caa5ab6 --- /dev/null +++ b/src/bin/logger/mod.rs @@ -0,0 +1,36 @@ +// 版权所有 (c) ling 保留所有权利。 +// 除非另行说明,否则仅允许在LingTransmit中使用此文件中的代码。 +// +// 由 ling 创建于 2025/2/23. +#![allow(non_snake_case)] + +use chrono::Local; +use colored::{Color, Colorize}; +use fern::Dispatch; +use log::{Level, LevelFilter}; + +pub fn init_log() { + let console_dispatch = Dispatch::new() + .format(|out, message, record| { + let (title, color) = match record.level() { + Level::Error => ("Error", Color::Red), + Level::Warn => ("Warning", Color::Yellow), + Level::Info => ("Info", Color::Green), + Level::Debug => ("Debug", Color::BrightWhite), + Level::Trace => ("Trace", Color::White), + }; + + out.finish(format_args!( + "{}", + format!("[{} {}]\t{}", get_time(), title, message).color(color) + )) + }) + .chain(std::io::stdout()) + .level(LevelFilter::Trace) + .apply(); +} + +fn get_time() -> String { + let now = Local::now(); + now.format("%Y-%m-%d %H:%M:%S").to_string() +} \ No newline at end of file diff --git a/src/bin/main.rs b/src/bin/main.rs index d18b4d3..80d09f0 100644 --- a/src/bin/main.rs +++ b/src/bin/main.rs @@ -3,19 +3,21 @@ // // 由 ling 创建于 2025/1/18. #![allow(non_snake_case)] +mod logger; use async_trait::async_trait; use chrono::Local; use colored::{Color, Colorize}; use fern::Dispatch; use log::{debug, info, Level, LevelFilter}; -use std::sync::{Arc}; +use std::sync::Arc; use tokio::task::JoinHandle; use LingTransmit::server::event::ServerEvent; use LingTransmit::server::Client::Client; use LingTransmit::server::Server; use LingTransmit::shell::{register_command, start_shell, CommandActuators}; use LingTransmit::ssl::ServerCert; +use crate::logger::init_log; #[tokio::main] async fn main() { @@ -47,30 +49,6 @@ fn start_server(server: Arc) -> JoinHandle<()> { }) } -fn get_time() -> String { - let now = Local::now(); - now.format("%Y-%m-%d %H:%M:%S").to_string() -} -fn init_log() { - let console_dispatch = Dispatch::new() - .format(|out, message, record| { - let (title, color) = match record.level() { - Level::Error => ("Error", Color::Red), - Level::Warn => ("Warning", Color::Yellow), - Level::Info => ("Info", Color::Green), - Level::Debug => ("Debug", Color::BrightWhite), - Level::Trace => ("Trace", Color::White), - }; - - out.finish(format_args!( - "{}", - format!("[{} {}]\t{}", get_time(), title, message).color(color) - )) - }) - .chain(std::io::stdout()) - .level(LevelFilter::Trace) - .apply(); -} struct ExitCommand { server: Arc, @@ -103,11 +81,16 @@ impl ServerEvent for Event { } async fn client_user_data(&self, client: Arc, packet: Vec) -> std::io::Result<()> { + let str = String::from_utf8_lossy(&packet).to_string(); debug!( - "客户端发送数据,ID:{},数据长度:{}", + "客户端发送数据,ID:{},数据长度:{},数据内容:{}", client.id, - packet.len() + packet.len(), + str, ); + client + .send_buffer(format!("收到了:{}", str).as_bytes()) + .await?; Ok(()) } } diff --git a/src/client/mod.rs b/src/client/mod.rs index 4f28d50..9ff85a0 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -5,12 +5,14 @@ #![allow(non_snake_case)] use crate::packet::code::*; +use crate::packet::NetworkPackets; use crate::stream::{OwnedReadHalfAbstraction, OwnedWriteHalfAbstraction}; use chrono::{DateTime, NaiveDateTime, Utc}; use log::trace; use openssl::asn1::Asn1Time; use openssl::pkey::Public; use openssl::rsa::{Padding, Rsa}; +use openssl::symm::{decrypt, encrypt, Cipher}; use openssl::x509::{X509NameEntryRef, X509}; use rand::Rng; use std::io; @@ -27,7 +29,7 @@ pub type ClientWrite = Arc>; pub struct Client { read: ClientRead, write: ClientWrite, - key: String, + key: Option, } /// 生成会话密钥 @@ -44,7 +46,7 @@ fn generate_key() -> String { } impl Client { - fn init(read: ClientRead, write: ClientWrite, key: String) -> Self { + fn init(read: ClientRead, write: ClientWrite, key: Option) -> Self { Client { read, write, key } } @@ -72,7 +74,11 @@ impl Client { })?; let rsa = public.rsa()?; Self::push_aes_key(read.clone(), write.clone(), &key, &rsa).await?; - Ok(Client { read, write, key }) + Ok(Client { + read, + write, + key: Some(key), + }) } /// 推送会话密钥 @@ -82,9 +88,13 @@ impl Client { key: &String, rsa: &Rsa, ) -> io::Result<()> { + let mut buffer_key = Vec::new(); + buffer_key.push('\n' as u8); + buffer_key.extend_from_slice(&Self::get_ling_transmit_version().to_le_bytes()); + buffer_key.extend_from_slice(key.as_bytes()); //使用服务器公钥加密会话密钥,然后发送给服务器 let mut buffer = vec![0u8; rsa.size() as usize]; - let size = rsa.public_encrypt(key.as_bytes(), &mut buffer, Padding::PKCS1)?; + let size = rsa.public_encrypt(&buffer_key, &mut buffer, Padding::PKCS1)?; let mut read = read.lock().await; let mut write = write.lock().await; @@ -127,6 +137,101 @@ impl Client { Ok(buffer) } + + /// 处理用于维护协议的数据包 + pub async fn execute_protocol_packet(&self, packet: NetworkPackets) -> io::Result<()> { + //协议暂时没有提出要求 + Ok(()) + } + + /// 按照协议读取一个用户数据包 + pub async fn read_packet(&self) -> io::Result { + let mut read = self.read.lock().await; + let start = read.read_i32_le().await?; + if start != LING_START { + return Err(io::Error::new( + io::ErrorKind::NetworkDown, + format!( + "协议被破坏:数据包起始标记错误,预期{:0x},发现{:0x}", + LING_START, start + ), + )); + } + let size = read.read_i32_le().await?; + let data_type = match read.read_i32_le().await? { + API_TYPE_ASK => API_TYPE_ASK, + num => { + return Err(io::Error::new( + io::ErrorKind::NetworkDown, + format!("协议被破坏:无效的数据包类型:{}", num), + )); + } + }; + let mut buffer = Vec::new(); + buffer.resize(size as usize, 0u8); + let read_size = read.read_exact(&mut buffer).await?; + if read_size != size as usize { + return Err(io::Error::new( + io::ErrorKind::NetworkDown, + format!("残缺数据包:读取到 {} 字节,预期 {} 字节", read_size, size), + )); + } + let stop = read.read_i32_le().await?; + if stop != LING_STOP { + return Err(io::Error::new( + io::ErrorKind::NetworkDown, + format!( + "协议被破坏:数据包结束标记错误,预期{:0x},发现{:0x}", + LING_STOP, stop + ), + )); + } + match data_type { + API_TYPE_ASK => { + if let Some(key) = &self.key { + let data = decrypt( + Cipher::aes_256_cbc(), + &key.as_bytes()[..32], + Some(&key.as_bytes()[..16]), + &buffer, + )?; + Ok(NetworkPackets::UserAsk(data)) + } else { + Ok(NetworkPackets::UserAsk(buffer)) + } + } + _ => Err(io::Error::new( + io::ErrorKind::NetworkDown, + "未知的数据包类型", + )), + } + } + + /// 按协议发送一个用户数据包 + pub async fn send_buffer(&self, buffer: &[u8]) -> io::Result<()> { + let data = if let Some(key) = &self.key { + encrypt( + Cipher::aes_256_cbc(), + &key.as_bytes()[..32], + Some(&key.as_bytes()[..16]), + buffer, + )? + } else { + buffer.to_vec() + }; + let mut write = self.write.lock().await; + write.write_i32_le(LING_START).await?; + write.write_i32_le(data.len() as i32).await?; + write.write_i32_le(API_TYPE_ASK).await?; + write.write_all(&data).await?; + write.write_i32_le(LING_STOP).await?; + Ok(()) + } + + /// 获取当前客户端使用的协议版本号 + pub fn get_ling_transmit_version() -> i32 { + LING_V1_1 + } } fn check_certificate(host: String, cert: X509) -> io::Result<()> { diff --git a/src/packet/code.rs b/src/packet/code.rs index cdcb518..4bf2c8b 100644 --- a/src/packet/code.rs +++ b/src/packet/code.rs @@ -1,6 +1,6 @@ // 版权所有 (c) ling 保留所有权利。 // 除非另行说明,否则仅允许在LingTransmit中使用此文件中的代码。 -// +// // 由 ling 创建于 2025/1/18. #![allow(non_snake_case)] @@ -15,6 +15,12 @@ pub const MAX_SIZE_V1_0: i32 = 32 * 1024; /// Ling - 1.1 数据包最大长度 pub const MAX_SIZE_V1_1: i32 = 512 * 1024; +/// Ling - 1.0 +pub const LING_V1_0: i32 = 0; +pub const LING_V1_1: i32 = 1; +/// 最大有效版本号 +pub const LING_VERSION_MAX: i32 = LING_V1_1; + /// 服务器确认 pub const SERVER_ACK: i32 = 99999999; /// 服务器错误 @@ -22,5 +28,5 @@ pub const SERVER_ERROR: i32 = 100000000; /// 用户接口 pub const API_TYPE_ASK: i32 = 1; -/// 交换密钥 -pub const API_TYPE_PUSH_AES_KEY: i32 = 4; \ No newline at end of file +/// 交换密钥,并设定版本号 +pub const API_TYPE_PUSH_AES_KEY: i32 = 4; diff --git a/src/packet/mod.rs b/src/packet/mod.rs index b843be7..299af4f 100644 --- a/src/packet/mod.rs +++ b/src/packet/mod.rs @@ -8,12 +8,12 @@ pub mod code; use crate::packet::code::*; use crate::server::Client::Client; +use crate::stream::OwnedReadHalfAbstraction; use std::fmt::format; use std::sync::atomic::Ordering; use std::sync::Arc; use std::{error, io}; use tokio::io::AsyncReadExt; -use crate::stream::OwnedReadHalfAbstraction; /// 数据包 pub enum NetworkPackets { diff --git a/src/server/Client.rs b/src/server/Client.rs index 5aeec03..4fe30bc 100644 --- a/src/server/Client.rs +++ b/src/server/Client.rs @@ -5,7 +5,9 @@ #![allow(non_snake_case)] use crate::close_sender::CloseSender; -use crate::packet::code::{SERVER_ACK, SERVER_ERROR}; +use crate::packet::code::{ + API_TYPE_ASK, LING_START, LING_STOP, LING_V1_1, LING_VERSION_MAX, SERVER_ACK, SERVER_ERROR, +}; use crate::packet::{read_packet, NetworkPackets}; use crate::server::accept::SocketAddr; use crate::server::event::ServerEvent; @@ -14,9 +16,10 @@ use crate::ssl::ServerCert; use crate::stream::{OwnedReadHalfAbstraction, OwnedWriteHalfAbstraction}; use log::{error, info}; use openssl::rsa::Padding; +use openssl::symm::{decrypt, encrypt, Cipher}; use std::io; use std::string::FromUtf8Error; -use std::sync::atomic::{AtomicBool, AtomicI32}; +use std::sync::atomic::{AtomicBool, AtomicI32, Ordering}; use std::sync::{Arc, OnceLock}; use std::time::Duration; use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -137,11 +140,52 @@ impl Client { } } + /// 发送数据,自动将 buffer 打包为一个数据包并发送 + pub async fn send_buffer(&self, data: &[u8]) -> io::Result<()> { + let mut write = self.write_soc.lock().await; + // 仅在 Ling Transmit V 1.1及以上才需要携带包开始标记 + if self.syn_version.load(Ordering::Acquire) >= LING_V1_1 { + write.write_i32_le(LING_START).await?; + } + if let Some(key) = self.key.get() { + //传输加密后的包体 + let buffer = encrypt( + Cipher::aes_256_cbc(), + &key.as_bytes()[..32], + Some(&key.as_bytes()[..16]), + &*data, + )?; + write.write_i32_le(buffer.len() as i32).await?; + write.write_i32_le(API_TYPE_ASK).await?; + write.write_all(&buffer).await?; + } else { + write.write_i32_le(data.len() as i32).await?; + write.write_i32_le(API_TYPE_ASK).await?; + write.write_all(data).await?; + } + write.write_i32_le(LING_STOP).await?; + Ok(()) + } + /// 处理客户端请求 async fn process_packet(self: &Arc, packet: NetworkPackets) -> io::Result<()> { match packet { NetworkPackets::SynV1 => self.syn_v1().await, - NetworkPackets::UserAsk(data) => self.event.client_user_data(self.clone(), data).await, + NetworkPackets::UserAsk(data) => { + if let Some(key) = self.key.get() { + //使用会话密钥解密包体 + let data = decrypt( + Cipher::aes_256_cbc(), + &key.as_bytes()[..32], + Some(&key.as_bytes()[..16]), + &*data, + )?; + self.event.client_user_data(self.clone(), data).await + } else { + //在使用Unix Domain Socket时无需加密数据 + self.event.client_user_data(self.clone(), data).await + } + } NetworkPackets::PushAesKey(data) => self.client_push_key(&data).await, } } @@ -152,22 +196,26 @@ impl Client { let mut data = Vec::new(); data.resize(rsa.size() as usize, 0u8); - //解密会话密钥 rsa.private_decrypt(buff, &mut data, Padding::PKCS1)?; - let key = match String::from_utf8(data) { - Ok(key) => key, - Err(_) => { - return Err(io::Error::new(io::ErrorKind::NotFound, "解密会话密钥失败")); - } - }; + let (key, version) = parse_session_key(&data)?; let mut send = self.write_soc.lock().await; + //检查版本号是否合法 + if version > LING_VERSION_MAX || version < 0 { + let message = format!("无效的版本号:{}", version); + send.write_i32_le(SERVER_ERROR).await?; + send.write_i32_le(message.len() as i32).await?; + send.write_all(message.as_bytes()).await?; + return Ok(()); + } + info!("协议版本号:{}", version); if let Err(_) = self.key.set(key.clone()) { //重复推送,拒绝密钥 send.write_i32_le(SERVER_ERROR).await?; return Ok(()); } + self.syn_version.store(version, Ordering::Relaxed); send.write_i32_le(SERVER_ACK).await?; Ok(()) } @@ -181,3 +229,42 @@ impl Client { Ok(()) } } + +/// 从PushAesKey中提取出会话密钥和协议版本 +fn parse_session_key(data: &Vec) -> io::Result<(String, i32)> { + if data.len() < 5 { + return Err(io::Error::new(io::ErrorKind::InvalidData, "数据量不足!")); + } + // 在Ling Transmit V1.1及以上版本中,如果PushAesKey数据以\n作为开始,则接下来四个字节将解析为协议版本号 + // 如果不以\n作为开始,则说明使用 Ling Transmit V1.0协议 + if data[0] != '\n' as u8 { + let key = match String::from_utf8(data.clone()) { + Ok(key) => key, + Err(_) => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "解密会话密钥失败", + )); + } + }; + return Ok((key, 0)); + } + + let slice = &data[1..5]; + let slice: [u8; 4] = slice[..4] + .try_into() + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("转换失败:{:?}", e)))?; + let version = i32::from_le_bytes(slice); + + let buffer = &data[5..]; + let key = match String::from_utf8(buffer.to_vec()) { + Ok(key) => key, + Err(_) => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "解析会话密钥失败", + )); + } + }; + Ok((key, version)) +} diff --git a/src/stream.rs b/src/stream.rs index a03ccbe..2da1d42 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -1,6 +1,6 @@ // 版权所有 (c) ling 保留所有权利。 // 除非另行说明,否则仅允许在LingTransmit中使用此文件中的代码。 -// +// // 由 ling 创建于 2025/1/19. #![allow(non_snake_case)]