实现传输过程

This commit is contained in:
2025-02-23 22:15:44 +08:00
parent 12fd63604a
commit 67f9114767
8 changed files with 279 additions and 46 deletions

View File

@@ -4,12 +4,28 @@
// 由 ling 创建于 2025/2/23.
#![allow(non_snake_case)]
mod logger;
use log::{error, info};
use openssl::x509::X509;
use LingTransmit::client::Client;
use LingTransmit::packet::NetworkPackets;
use crate::logger::init_log;
#[tokio::main]
async fn main() {
init_log();
let cert = include_bytes!("../../assert/ZHSSCA.crt");
let cert = X509::from_pem(cert).unwrap();
let client = Client::tcp_connect("localhost", 11451, cert).await.unwrap();
client.send_buffer("你好,世界!".as_bytes()).await.unwrap();
let packet = client.read_packet().await.unwrap();
if let NetworkPackets::UserAsk(data) = packet {
let str = String::from_utf8_lossy(&*data).to_string();
info!("{}", str);
} else {
client.execute_protocol_packet(packet).await.unwrap();
}
}

36
src/bin/logger/mod.rs Normal file
View File

@@ -0,0 +1,36 @@
// 版权所有 (c) ling 保留所有权利。
// 除非另行说明否则仅允许在LingTransmit中使用此文件中的代码。
//
// 由 ling 创建于 2025/2/23.
#![allow(non_snake_case)]
use chrono::Local;
use colored::{Color, Colorize};
use fern::Dispatch;
use log::{Level, LevelFilter};
pub fn init_log() {
let console_dispatch = Dispatch::new()
.format(|out, message, record| {
let (title, color) = match record.level() {
Level::Error => ("Error", Color::Red),
Level::Warn => ("Warning", Color::Yellow),
Level::Info => ("Info", Color::Green),
Level::Debug => ("Debug", Color::BrightWhite),
Level::Trace => ("Trace", Color::White),
};
out.finish(format_args!(
"{}",
format!("[{} {}]\t{}", get_time(), title, message).color(color)
))
})
.chain(std::io::stdout())
.level(LevelFilter::Trace)
.apply();
}
fn get_time() -> String {
let now = Local::now();
now.format("%Y-%m-%d %H:%M:%S").to_string()
}

View File

@@ -3,19 +3,21 @@
//
// 由 ling 创建于 2025/1/18.
#![allow(non_snake_case)]
mod logger;
use async_trait::async_trait;
use chrono::Local;
use colored::{Color, Colorize};
use fern::Dispatch;
use log::{debug, info, Level, LevelFilter};
use std::sync::{Arc};
use std::sync::Arc;
use tokio::task::JoinHandle;
use LingTransmit::server::event::ServerEvent;
use LingTransmit::server::Client::Client;
use LingTransmit::server::Server;
use LingTransmit::shell::{register_command, start_shell, CommandActuators};
use LingTransmit::ssl::ServerCert;
use crate::logger::init_log;
#[tokio::main]
async fn main() {
@@ -47,30 +49,6 @@ fn start_server(server: Arc<Server>) -> JoinHandle<()> {
})
}
fn get_time() -> String {
let now = Local::now();
now.format("%Y-%m-%d %H:%M:%S").to_string()
}
fn init_log() {
let console_dispatch = Dispatch::new()
.format(|out, message, record| {
let (title, color) = match record.level() {
Level::Error => ("Error", Color::Red),
Level::Warn => ("Warning", Color::Yellow),
Level::Info => ("Info", Color::Green),
Level::Debug => ("Debug", Color::BrightWhite),
Level::Trace => ("Trace", Color::White),
};
out.finish(format_args!(
"{}",
format!("[{} {}]\t{}", get_time(), title, message).color(color)
))
})
.chain(std::io::stdout())
.level(LevelFilter::Trace)
.apply();
}
struct ExitCommand {
server: Arc<Server>,
@@ -103,11 +81,16 @@ impl ServerEvent for Event {
}
async fn client_user_data(&self, client: Arc<Client>, packet: Vec<u8>) -> std::io::Result<()> {
let str = String::from_utf8_lossy(&packet).to_string();
debug!(
"客户端发送数据ID{},数据长度:{}",
"客户端发送数据ID{},数据长度:{},数据内容:{}",
client.id,
packet.len()
packet.len(),
str,
);
client
.send_buffer(format!("收到了:{}", str).as_bytes())
.await?;
Ok(())
}
}

View File

@@ -5,12 +5,14 @@
#![allow(non_snake_case)]
use crate::packet::code::*;
use crate::packet::NetworkPackets;
use crate::stream::{OwnedReadHalfAbstraction, OwnedWriteHalfAbstraction};
use chrono::{DateTime, NaiveDateTime, Utc};
use log::trace;
use openssl::asn1::Asn1Time;
use openssl::pkey::Public;
use openssl::rsa::{Padding, Rsa};
use openssl::symm::{decrypt, encrypt, Cipher};
use openssl::x509::{X509NameEntryRef, X509};
use rand::Rng;
use std::io;
@@ -27,7 +29,7 @@ pub type ClientWrite = Arc<Mutex<dyn OwnedWriteHalfAbstraction>>;
pub struct Client {
read: ClientRead,
write: ClientWrite,
key: String,
key: Option<String>,
}
/// 生成会话密钥
@@ -44,7 +46,7 @@ fn generate_key() -> String {
}
impl Client {
fn init(read: ClientRead, write: ClientWrite, key: String) -> Self {
fn init(read: ClientRead, write: ClientWrite, key: Option<String>) -> Self {
Client { read, write, key }
}
@@ -72,7 +74,11 @@ impl Client {
})?;
let rsa = public.rsa()?;
Self::push_aes_key(read.clone(), write.clone(), &key, &rsa).await?;
Ok(Client { read, write, key })
Ok(Client {
read,
write,
key: Some(key),
})
}
/// 推送会话密钥
@@ -82,9 +88,13 @@ impl Client {
key: &String,
rsa: &Rsa<Public>,
) -> io::Result<()> {
let mut buffer_key = Vec::new();
buffer_key.push('\n' as u8);
buffer_key.extend_from_slice(&Self::get_ling_transmit_version().to_le_bytes());
buffer_key.extend_from_slice(key.as_bytes());
//使用服务器公钥加密会话密钥,然后发送给服务器
let mut buffer = vec![0u8; rsa.size() as usize];
let size = rsa.public_encrypt(key.as_bytes(), &mut buffer, Padding::PKCS1)?;
let size = rsa.public_encrypt(&buffer_key, &mut buffer, Padding::PKCS1)?;
let mut read = read.lock().await;
let mut write = write.lock().await;
@@ -127,6 +137,101 @@ impl Client {
Ok(buffer)
}
/// 处理用于维护协议的数据包
pub async fn execute_protocol_packet(&self, packet: NetworkPackets) -> io::Result<()> {
//协议暂时没有提出要求
Ok(())
}
/// 按照协议读取一个用户数据包
pub async fn read_packet(&self) -> io::Result<NetworkPackets> {
let mut read = self.read.lock().await;
let start = read.read_i32_le().await?;
if start != LING_START {
return Err(io::Error::new(
io::ErrorKind::NetworkDown,
format!(
"协议被破坏:数据包起始标记错误,预期{:0x},发现{:0x}",
LING_START, start
),
));
}
let size = read.read_i32_le().await?;
let data_type = match read.read_i32_le().await? {
API_TYPE_ASK => API_TYPE_ASK,
num => {
return Err(io::Error::new(
io::ErrorKind::NetworkDown,
format!("协议被破坏:无效的数据包类型:{}", num),
));
}
};
let mut buffer = Vec::new();
buffer.resize(size as usize, 0u8);
let read_size = read.read_exact(&mut buffer).await?;
if read_size != size as usize {
return Err(io::Error::new(
io::ErrorKind::NetworkDown,
format!("残缺数据包:读取到 {} 字节,预期 {} 字节", read_size, size),
));
}
let stop = read.read_i32_le().await?;
if stop != LING_STOP {
return Err(io::Error::new(
io::ErrorKind::NetworkDown,
format!(
"协议被破坏:数据包结束标记错误,预期{:0x},发现{:0x}",
LING_STOP, stop
),
));
}
match data_type {
API_TYPE_ASK => {
if let Some(key) = &self.key {
let data = decrypt(
Cipher::aes_256_cbc(),
&key.as_bytes()[..32],
Some(&key.as_bytes()[..16]),
&buffer,
)?;
Ok(NetworkPackets::UserAsk(data))
} else {
Ok(NetworkPackets::UserAsk(buffer))
}
}
_ => Err(io::Error::new(
io::ErrorKind::NetworkDown,
"未知的数据包类型",
)),
}
}
/// 按协议发送一个用户数据包
pub async fn send_buffer(&self, buffer: &[u8]) -> io::Result<()> {
let data = if let Some(key) = &self.key {
encrypt(
Cipher::aes_256_cbc(),
&key.as_bytes()[..32],
Some(&key.as_bytes()[..16]),
buffer,
)?
} else {
buffer.to_vec()
};
let mut write = self.write.lock().await;
write.write_i32_le(LING_START).await?;
write.write_i32_le(data.len() as i32).await?;
write.write_i32_le(API_TYPE_ASK).await?;
write.write_all(&data).await?;
write.write_i32_le(LING_STOP).await?;
Ok(())
}
/// 获取当前客户端使用的协议版本号
pub fn get_ling_transmit_version() -> i32 {
LING_V1_1
}
}
fn check_certificate(host: String, cert: X509) -> io::Result<()> {

View File

@@ -1,6 +1,6 @@
// 版权所有 (c) ling 保留所有权利。
// 除非另行说明否则仅允许在LingTransmit中使用此文件中的代码。
//
//
// 由 ling 创建于 2025/1/18.
#![allow(non_snake_case)]
@@ -15,6 +15,12 @@ pub const MAX_SIZE_V1_0: i32 = 32 * 1024;
/// Ling - 1.1 数据包最大长度
pub const MAX_SIZE_V1_1: i32 = 512 * 1024;
/// Ling - 1.0
pub const LING_V1_0: i32 = 0;
pub const LING_V1_1: i32 = 1;
/// 最大有效版本号
pub const LING_VERSION_MAX: i32 = LING_V1_1;
/// 服务器确认
pub const SERVER_ACK: i32 = 99999999;
/// 服务器错误
@@ -22,5 +28,5 @@ pub const SERVER_ERROR: i32 = 100000000;
/// 用户接口
pub const API_TYPE_ASK: i32 = 1;
/// 交换密钥
pub const API_TYPE_PUSH_AES_KEY: i32 = 4;
/// 交换密钥,并设定版本号
pub const API_TYPE_PUSH_AES_KEY: i32 = 4;

View File

@@ -8,12 +8,12 @@ pub mod code;
use crate::packet::code::*;
use crate::server::Client::Client;
use crate::stream::OwnedReadHalfAbstraction;
use std::fmt::format;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::{error, io};
use tokio::io::AsyncReadExt;
use crate::stream::OwnedReadHalfAbstraction;
/// 数据包
pub enum NetworkPackets {

View File

@@ -5,7 +5,9 @@
#![allow(non_snake_case)]
use crate::close_sender::CloseSender;
use crate::packet::code::{SERVER_ACK, SERVER_ERROR};
use crate::packet::code::{
API_TYPE_ASK, LING_START, LING_STOP, LING_V1_1, LING_VERSION_MAX, SERVER_ACK, SERVER_ERROR,
};
use crate::packet::{read_packet, NetworkPackets};
use crate::server::accept::SocketAddr;
use crate::server::event::ServerEvent;
@@ -14,9 +16,10 @@ use crate::ssl::ServerCert;
use crate::stream::{OwnedReadHalfAbstraction, OwnedWriteHalfAbstraction};
use log::{error, info};
use openssl::rsa::Padding;
use openssl::symm::{decrypt, encrypt, Cipher};
use std::io;
use std::string::FromUtf8Error;
use std::sync::atomic::{AtomicBool, AtomicI32};
use std::sync::atomic::{AtomicBool, AtomicI32, Ordering};
use std::sync::{Arc, OnceLock};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
@@ -137,11 +140,52 @@ impl Client {
}
}
/// 发送数据,自动将 buffer 打包为一个数据包并发送
pub async fn send_buffer(&self, data: &[u8]) -> io::Result<()> {
let mut write = self.write_soc.lock().await;
// 仅在 Ling Transmit V 1.1及以上才需要携带包开始标记
if self.syn_version.load(Ordering::Acquire) >= LING_V1_1 {
write.write_i32_le(LING_START).await?;
}
if let Some(key) = self.key.get() {
//传输加密后的包体
let buffer = encrypt(
Cipher::aes_256_cbc(),
&key.as_bytes()[..32],
Some(&key.as_bytes()[..16]),
&*data,
)?;
write.write_i32_le(buffer.len() as i32).await?;
write.write_i32_le(API_TYPE_ASK).await?;
write.write_all(&buffer).await?;
} else {
write.write_i32_le(data.len() as i32).await?;
write.write_i32_le(API_TYPE_ASK).await?;
write.write_all(data).await?;
}
write.write_i32_le(LING_STOP).await?;
Ok(())
}
/// 处理客户端请求
async fn process_packet(self: &Arc<Client>, packet: NetworkPackets) -> io::Result<()> {
match packet {
NetworkPackets::SynV1 => self.syn_v1().await,
NetworkPackets::UserAsk(data) => self.event.client_user_data(self.clone(), data).await,
NetworkPackets::UserAsk(data) => {
if let Some(key) = self.key.get() {
//使用会话密钥解密包体
let data = decrypt(
Cipher::aes_256_cbc(),
&key.as_bytes()[..32],
Some(&key.as_bytes()[..16]),
&*data,
)?;
self.event.client_user_data(self.clone(), data).await
} else {
//在使用Unix Domain Socket时无需加密数据
self.event.client_user_data(self.clone(), data).await
}
}
NetworkPackets::PushAesKey(data) => self.client_push_key(&data).await,
}
}
@@ -152,22 +196,26 @@ impl Client {
let mut data = Vec::new();
data.resize(rsa.size() as usize, 0u8);
//解密会话密钥
rsa.private_decrypt(buff, &mut data, Padding::PKCS1)?;
let key = match String::from_utf8(data) {
Ok(key) => key,
Err(_) => {
return Err(io::Error::new(io::ErrorKind::NotFound, "解密会话密钥失败"));
}
};
let (key, version) = parse_session_key(&data)?;
let mut send = self.write_soc.lock().await;
//检查版本号是否合法
if version > LING_VERSION_MAX || version < 0 {
let message = format!("无效的版本号:{}", version);
send.write_i32_le(SERVER_ERROR).await?;
send.write_i32_le(message.len() as i32).await?;
send.write_all(message.as_bytes()).await?;
return Ok(());
}
info!("协议版本号:{}", version);
if let Err(_) = self.key.set(key.clone()) {
//重复推送,拒绝密钥
send.write_i32_le(SERVER_ERROR).await?;
return Ok(());
}
self.syn_version.store(version, Ordering::Relaxed);
send.write_i32_le(SERVER_ACK).await?;
Ok(())
}
@@ -181,3 +229,42 @@ impl Client {
Ok(())
}
}
/// 从PushAesKey中提取出会话密钥和协议版本
fn parse_session_key(data: &Vec<u8>) -> io::Result<(String, i32)> {
if data.len() < 5 {
return Err(io::Error::new(io::ErrorKind::InvalidData, "数据量不足!"));
}
// 在Ling Transmit V1.1及以上版本中如果PushAesKey数据以\n作为开始则接下来四个字节将解析为协议版本号
// 如果不以\n作为开始则说明使用 Ling Transmit V1.0协议
if data[0] != '\n' as u8 {
let key = match String::from_utf8(data.clone()) {
Ok(key) => key,
Err(_) => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"解密会话密钥失败",
));
}
};
return Ok((key, 0));
}
let slice = &data[1..5];
let slice: [u8; 4] = slice[..4]
.try_into()
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("转换失败:{:?}", e)))?;
let version = i32::from_le_bytes(slice);
let buffer = &data[5..];
let key = match String::from_utf8(buffer.to_vec()) {
Ok(key) => key,
Err(_) => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"解析会话密钥失败",
));
}
};
Ok((key, version))
}

View File

@@ -1,6 +1,6 @@
// 版权所有 (c) ling 保留所有权利。
// 除非另行说明否则仅允许在LingTransmit中使用此文件中的代码。
//
//
// 由 ling 创建于 2025/1/19.
#![allow(non_snake_case)]