server improvements

This commit is contained in:
Mateusz Faderewski 2024-07-02 23:06:42 +02:00
parent 3f639855d5
commit b915a7b4e3
2 changed files with 54 additions and 78 deletions

View File

@ -64,7 +64,7 @@ const SERIAL_PREFIX: &str = "serial://";
const FTDI_PREFIX: &str = "ftdi://"; const FTDI_PREFIX: &str = "ftdi://";
const RESET_TIMEOUT: Duration = Duration::from_secs(1); const RESET_TIMEOUT: Duration = Duration::from_secs(1);
const POLL_TIMEOUT: Duration = Duration::from_millis(10); const POLL_TIMEOUT: Duration = Duration::from_millis(5);
const READ_TIMEOUT: Duration = Duration::from_secs(5); const READ_TIMEOUT: Duration = Duration::from_secs(5);
const WRITE_TIMEOUT: Duration = Duration::from_secs(5); const WRITE_TIMEOUT: Duration = Duration::from_secs(5);

View File

@ -4,11 +4,7 @@ use super::{
list_local_devices, new_local, AsynchronousPacket, Command, DataType, Response, UsbPacket, list_local_devices, new_local, AsynchronousPacket, Command, DataType, Response, UsbPacket,
}, },
}; };
use std::{ use std::io::{Read, Write};
io::{Read, Write},
net::{TcpListener, TcpStream},
time::{Duration, Instant},
};
pub enum ServerEvent { pub enum ServerEvent {
Listening(String), Listening(String),
@ -18,60 +14,40 @@ pub enum ServerEvent {
} }
struct StreamHandler { struct StreamHandler {
stream: TcpStream, stream: std::net::TcpStream,
reader: std::io::BufReader<std::net::TcpStream>,
writer: std::io::BufWriter<std::net::TcpStream>,
} }
const POLL_TIMEOUT: Duration = Duration::from_millis(1); const READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);
const READ_TIMEOUT: Duration = Duration::from_secs(5); const WRITE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);
const WRITE_TIMEOUT: Duration = Duration::from_secs(5); const KEEPALIVE_PERIOD: std::time::Duration = std::time::Duration::from_secs(5);
const KEEPALIVE_PERIOD: Duration = Duration::from_secs(5);
impl StreamHandler { impl StreamHandler {
fn new(stream: TcpStream) -> std::io::Result<StreamHandler> { fn new(stream: std::net::TcpStream) -> std::io::Result<StreamHandler> {
let reader = std::io::BufReader::new(stream.try_clone()?);
let writer = std::io::BufWriter::new(stream.try_clone()?);
stream.set_read_timeout(Some(READ_TIMEOUT))?; stream.set_read_timeout(Some(READ_TIMEOUT))?;
stream.set_write_timeout(Some(WRITE_TIMEOUT))?; stream.set_write_timeout(Some(WRITE_TIMEOUT))?;
Ok(StreamHandler { stream }) Ok(StreamHandler {
} stream,
reader,
fn try_read_exact(&mut self, buffer: &mut [u8]) -> std::io::Result<Option<()>> { writer,
let mut position = 0; })
let length = buffer.len();
let timeout = Instant::now();
self.stream.set_read_timeout(Some(POLL_TIMEOUT))?;
while position < length {
match self.stream.read(&mut buffer[position..length]) {
Ok(0) => return Err(std::io::ErrorKind::UnexpectedEof.into()),
Ok(bytes) => position += bytes,
Err(error) => match error.kind() {
std::io::ErrorKind::Interrupted
| std::io::ErrorKind::TimedOut
| std::io::ErrorKind::WouldBlock => {
if position == 0 {
break;
}
}
_ => return Err(error),
},
}
if timeout.elapsed() > READ_TIMEOUT {
return Err(std::io::ErrorKind::TimedOut.into());
}
}
self.stream.set_read_timeout(Some(READ_TIMEOUT))?;
if position > 0 {
Ok(Some(()))
} else {
Ok(None)
}
} }
fn try_read_header(&mut self) -> std::io::Result<Option<[u8; 4]>> { fn try_read_header(&mut self) -> std::io::Result<Option<[u8; 4]>> {
self.stream.set_nonblocking(true)?;
let mut header = [0u8; 4]; let mut header = [0u8; 4];
Ok(self.try_read_exact(&mut header)?.map(|_| header)) let result = match self.reader.read_exact(&mut header) {
Ok(()) => Ok(Some(header)),
Err(error) => match error.kind() {
std::io::ErrorKind::WouldBlock => Ok(None),
_ => Err(error),
},
};
self.stream.set_nonblocking(false)?;
result
} }
fn receive_command(&mut self) -> std::io::Result<Option<Command>> { fn receive_command(&mut self) -> std::io::Result<Option<Command>> {
@ -88,18 +64,18 @@ impl StreamHandler {
let mut id_buffer = [0u8; 1]; let mut id_buffer = [0u8; 1];
let mut args = [0u32; 2]; let mut args = [0u32; 2];
self.stream.read_exact(&mut id_buffer)?; self.reader.read_exact(&mut id_buffer)?;
let id = id_buffer[0]; let id = id_buffer[0];
self.stream.read_exact(&mut buffer)?; self.reader.read_exact(&mut buffer)?;
args[0] = u32::from_be_bytes(buffer); args[0] = u32::from_be_bytes(buffer);
self.stream.read_exact(&mut buffer)?; self.reader.read_exact(&mut buffer)?;
args[1] = u32::from_be_bytes(buffer); args[1] = u32::from_be_bytes(buffer);
self.stream.read_exact(&mut buffer)?; self.reader.read_exact(&mut buffer)?;
let command_data_length = u32::from_be_bytes(buffer) as usize; let command_data_length = u32::from_be_bytes(buffer) as usize;
let mut data = vec![0u8; command_data_length]; let mut data = vec![0u8; command_data_length];
self.stream.read_exact(&mut data)?; self.reader.read_exact(&mut data)?;
Ok(Some(Command { id, args, data })) Ok(Some(Command { id, args, data }))
} else { } else {
@ -108,32 +84,32 @@ impl StreamHandler {
} }
fn send_response(&mut self, response: Response) -> std::io::Result<()> { fn send_response(&mut self, response: Response) -> std::io::Result<()> {
self.stream self.writer
.write_all(&u32::to_be_bytes(DataType::Response.into()))?; .write_all(&u32::to_be_bytes(DataType::Response.into()))?;
self.stream.write_all(&[response.id])?; self.writer.write_all(&[response.id])?;
self.stream.write_all(&[response.error as u8])?; self.writer.write_all(&[response.error as u8])?;
self.stream self.writer
.write_all(&(response.data.len() as u32).to_be_bytes())?; .write_all(&(response.data.len() as u32).to_be_bytes())?;
self.stream.write_all(&response.data)?; self.writer.write_all(&response.data)?;
self.stream.flush()?; self.writer.flush()?;
Ok(()) Ok(())
} }
fn send_packet(&mut self, packet: AsynchronousPacket) -> std::io::Result<()> { fn send_packet(&mut self, packet: AsynchronousPacket) -> std::io::Result<()> {
self.stream self.writer
.write_all(&u32::to_be_bytes(DataType::Packet.into()))?; .write_all(&u32::to_be_bytes(DataType::Packet.into()))?;
self.stream.write_all(&[packet.id])?; self.writer.write_all(&[packet.id])?;
self.stream self.writer
.write_all(&(packet.data.len() as u32).to_be_bytes())?; .write_all(&(packet.data.len() as u32).to_be_bytes())?;
self.stream.write_all(&packet.data)?; self.writer.write_all(&packet.data)?;
self.stream.flush()?; self.writer.flush()?;
Ok(()) Ok(())
} }
fn send_keepalive(&mut self) -> std::io::Result<()> { fn send_keepalive(&mut self) -> std::io::Result<()> {
self.stream self.writer
.write_all(&u32::to_be_bytes(DataType::KeepAlive.into()))?; .write_all(&u32::to_be_bytes(DataType::KeepAlive.into()))?;
self.stream.flush()?; self.writer.flush()?;
Ok(()) Ok(())
} }
} }
@ -141,16 +117,9 @@ impl StreamHandler {
fn server_accept_connection(port: String, connection: &mut StreamHandler) -> Result<(), Error> { fn server_accept_connection(port: String, connection: &mut StreamHandler) -> Result<(), Error> {
let mut link = new_local(&port)?; let mut link = new_local(&port)?;
let mut keepalive = Instant::now(); let mut keepalive = std::time::Instant::now();
loop { loop {
while let Some(usb_packet) = link.receive_response_or_packet()? {
match usb_packet {
UsbPacket::Response(response) => connection.send_response(response)?,
UsbPacket::AsynchronousPacket(packet) => connection.send_packet(packet)?,
}
}
match connection.receive_command() { match connection.receive_command() {
Ok(Some(command)) => { Ok(Some(command)) => {
link.execute_command_raw(&command, true, true)?; link.execute_command_raw(&command, true, true)?;
@ -162,8 +131,15 @@ fn server_accept_connection(port: String, connection: &mut StreamHandler) -> Res
}, },
}; };
if let Some(usb_packet) = link.receive_response_or_packet()? {
match usb_packet {
UsbPacket::Response(response) => connection.send_response(response)?,
UsbPacket::AsynchronousPacket(packet) => connection.send_packet(packet)?,
}
}
if keepalive.elapsed() > KEEPALIVE_PERIOD { if keepalive.elapsed() > KEEPALIVE_PERIOD {
keepalive = Instant::now(); keepalive = std::time::Instant::now();
connection.send_keepalive().ok(); connection.send_keepalive().ok();
} }
} }
@ -175,7 +151,7 @@ pub fn run(
event_callback: fn(ServerEvent), event_callback: fn(ServerEvent),
) -> Result<(), Error> { ) -> Result<(), Error> {
let port = port.unwrap_or(list_local_devices()?[0].port.clone()); let port = port.unwrap_or(list_local_devices()?[0].port.clone());
let listener = TcpListener::bind(address)?; let listener = std::net::TcpListener::bind(address)?;
let listening_address = listener.local_addr()?; let listening_address = listener.local_addr()?;
event_callback(ServerEvent::Listening(listening_address.to_string())); event_callback(ServerEvent::Listening(listening_address.to_string()));