1
0
Fork 0
mirror of https://gitlab.com/futo-org/fcast.git synced 2025-06-24 21:25:23 +00:00

rs-terminal: buffer data read from websocket when underflowing

This commit is contained in:
Marcus Hanestad 2025-06-04 10:06:19 +02:00
parent 6529d91eb2
commit 3a7b7675ba
2 changed files with 94 additions and 17 deletions

View file

@ -1,4 +1,5 @@
use clap::{App, Arg, SubCommand};
use fcast::transport::WebSocket;
use std::collections::HashMap;
use std::net::IpAddr;
use std::str::FromStr;
@ -204,6 +205,7 @@ fn run() -> Result<(), Box<dyn std::error::Error>> {
MaybeTlsStream::Plain(ref stream) => Some(stream.local_addr()?.ip()),
_ => return Err("Established connection type is not plain.".into()),
};
let stream = WebSocket::new(stream);
FCastSession::connect(stream)?
}
_ => return Err("Invalid connection type.".into()),

View file

@ -1,6 +1,7 @@
use std::collections::VecDeque;
use std::io::{Read, Write};
use std::net::TcpStream;
use tungstenite::protocol::WebSocket;
use tungstenite::protocol::WebSocket as TWebSocket;
use tungstenite::Message;
pub trait Transport {
@ -28,28 +29,69 @@ impl Transport for TcpStream {
}
}
impl<T: Read + Write> Transport for WebSocket<T> {
fn transport_read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
match self.read() {
Ok(Message::Binary(data)) => {
let len = std::cmp::min(buf.len(), data.len());
buf[..len].copy_from_slice(&data[..len]);
Ok(len)
}
_ => Err(std::io::Error::other("Invalid message type")),
pub struct WebSocket<T>
where
T: Read + Write,
{
inner: TWebSocket<T>,
buffer: VecDeque<u8>,
}
impl<T> WebSocket<T>
where
T: Read + Write,
{
pub fn new(web_socket: TWebSocket<T>) -> Self {
Self {
inner: web_socket,
buffer: VecDeque::new(),
}
}
pub fn read_buffered(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
if !self.buffer.is_empty() {
let bytes_to_read = buf.len().min(self.buffer.len());
assert!(buf.len() >= bytes_to_read);
assert!(self.buffer.len() >= bytes_to_read);
for i in 0..bytes_to_read {
buf[i] = self.buffer.pop_front().unwrap(); // Safe unwrap as bounds was checked previously
}
} else {
match self.inner.read() {
Ok(Message::Binary(data)) => {
let bytes_to_read = buf.len().min(data.len());
buf.copy_from_slice(&data[..bytes_to_read]);
for rest in data[bytes_to_read..].iter() {
self.buffer.push_back(*rest);
}
}
_ => return Err(std::io::Error::other("Invalid message type")),
}
}
Ok(buf.len())
}
}
impl<T> Transport for WebSocket<T>
where
T: Read + Write,
{
fn transport_read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
self.read_buffered(buf)
}
fn transport_write(&mut self, buf: &[u8]) -> Result<(), std::io::Error> {
self.write(Message::Binary(buf.to_vec()))
self.inner
.write(Message::Binary(buf.to_vec()))
.map_err(std::io::Error::other)?;
self.flush().map_err(std::io::Error::other)
self.inner.flush().map_err(std::io::Error::other)
}
fn transport_shutdown(&mut self) -> Result<(), std::io::Error> {
self.close(None).map_err(std::io::Error::other)?;
self.inner.close(None).map_err(std::io::Error::other)?;
loop {
match self.read() {
match self.inner.read() {
Ok(_) => continue,
Err(tungstenite::Error::ConnectionClosed) => break,
Err(e) => return Err(std::io::Error::other(e)),
@ -62,7 +104,7 @@ impl<T: Read + Write> Transport for WebSocket<T> {
fn transport_read_exact(&mut self, buf: &mut [u8]) -> Result<(), std::io::Error> {
let mut total_read = 0;
while total_read < buf.len() {
total_read += self.transport_read(&mut buf[total_read..])?;
total_read += self.read_buffered(&mut buf[total_read..])?;
}
Ok(())
@ -75,6 +117,38 @@ mod tests {
use super::*;
#[test]
fn websocket_read_buffered() {
let jh = std::thread::spawn(|| {
let server = TcpListener::bind("127.0.0.1:51232").unwrap();
let stream = server.incoming().next().unwrap().unwrap();
let mut websocket = tungstenite::accept(stream).unwrap();
websocket
.send(tungstenite::Message::binary([1, 2, 3, 4]))
.unwrap();
websocket
.send(tungstenite::Message::binary([5, 6, 7, 8]))
.unwrap();
});
let (websocket, _) = tungstenite::connect("ws://127.0.0.1:51232").unwrap();
let mut websocket = WebSocket::new(websocket);
let mut buf = [0u8; 2];
assert_eq!(websocket.read_buffered(&mut buf).unwrap(), 2);
assert_eq!(buf, [1, 2]);
assert_eq!(websocket.read_buffered(&mut buf).unwrap(), 2);
assert_eq!(buf, [3, 4]);
let mut buf = [0u8; 4];
assert_eq!(websocket.read_buffered(&mut buf).unwrap(), 4);
assert_eq!(buf, [5, 6, 7, 8]);
let _ = websocket.transport_shutdown();
jh.join().unwrap();
}
#[test]
fn websocket_read_exact() {
let jh = std::thread::spawn(|| {
@ -86,7 +160,8 @@ mod tests {
.unwrap();
});
let (mut websocket, _) = tungstenite::connect("ws://127.0.0.1:51234").unwrap();
let (websocket, _) = tungstenite::connect("ws://127.0.0.1:51234").unwrap();
let mut websocket = WebSocket::new(websocket);
fn read_exact<T: Transport>(stream: &mut T) {
let mut buf = [0u8; 3];
@ -96,7 +171,7 @@ mod tests {
read_exact(&mut websocket);
websocket.close(None).unwrap();
let _ = websocket.transport_shutdown();
jh.join().unwrap();
}