1
0
Fork 0
mirror of https://gitlab.com/futo-org/fcast.git synced 2025-06-24 21:25:23 +00:00

rs-terminal: buffer data read from websocket when underflowing

This commit is contained in:
Marcus Hanestad 2025-06-04 10:06:19 +02:00
parent 6529d91eb2
commit 3a7b7675ba
2 changed files with 94 additions and 17 deletions

View file

@ -1,4 +1,5 @@
use clap::{App, Arg, SubCommand}; use clap::{App, Arg, SubCommand};
use fcast::transport::WebSocket;
use std::collections::HashMap; use std::collections::HashMap;
use std::net::IpAddr; use std::net::IpAddr;
use std::str::FromStr; use std::str::FromStr;
@ -204,6 +205,7 @@ fn run() -> Result<(), Box<dyn std::error::Error>> {
MaybeTlsStream::Plain(ref stream) => Some(stream.local_addr()?.ip()), MaybeTlsStream::Plain(ref stream) => Some(stream.local_addr()?.ip()),
_ => return Err("Established connection type is not plain.".into()), _ => return Err("Established connection type is not plain.".into()),
}; };
let stream = WebSocket::new(stream);
FCastSession::connect(stream)? FCastSession::connect(stream)?
} }
_ => return Err("Invalid connection type.".into()), _ => return Err("Invalid connection type.".into()),

View file

@ -1,6 +1,7 @@
use std::collections::VecDeque;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::net::TcpStream; use std::net::TcpStream;
use tungstenite::protocol::WebSocket; use tungstenite::protocol::WebSocket as TWebSocket;
use tungstenite::Message; use tungstenite::Message;
pub trait Transport { pub trait Transport {
@ -28,28 +29,69 @@ impl Transport for TcpStream {
} }
} }
impl<T: Read + Write> Transport for WebSocket<T> { pub struct WebSocket<T>
fn transport_read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> { where
match self.read() { T: Read + Write,
Ok(Message::Binary(data)) => { {
let len = std::cmp::min(buf.len(), data.len()); inner: TWebSocket<T>,
buf[..len].copy_from_slice(&data[..len]); buffer: VecDeque<u8>,
Ok(len) }
}
_ => Err(std::io::Error::other("Invalid message type")), impl<T> WebSocket<T>
where
T: Read + Write,
{
pub fn new(web_socket: TWebSocket<T>) -> Self {
Self {
inner: web_socket,
buffer: VecDeque::new(),
} }
} }
pub fn read_buffered(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
if !self.buffer.is_empty() {
let bytes_to_read = buf.len().min(self.buffer.len());
assert!(buf.len() >= bytes_to_read);
assert!(self.buffer.len() >= bytes_to_read);
for i in 0..bytes_to_read {
buf[i] = self.buffer.pop_front().unwrap(); // Safe unwrap as bounds was checked previously
}
} else {
match self.inner.read() {
Ok(Message::Binary(data)) => {
let bytes_to_read = buf.len().min(data.len());
buf.copy_from_slice(&data[..bytes_to_read]);
for rest in data[bytes_to_read..].iter() {
self.buffer.push_back(*rest);
}
}
_ => return Err(std::io::Error::other("Invalid message type")),
}
}
Ok(buf.len())
}
}
impl<T> Transport for WebSocket<T>
where
T: Read + Write,
{
fn transport_read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
self.read_buffered(buf)
}
fn transport_write(&mut self, buf: &[u8]) -> Result<(), std::io::Error> { fn transport_write(&mut self, buf: &[u8]) -> Result<(), std::io::Error> {
self.write(Message::Binary(buf.to_vec())) self.inner
.write(Message::Binary(buf.to_vec()))
.map_err(std::io::Error::other)?; .map_err(std::io::Error::other)?;
self.flush().map_err(std::io::Error::other) self.inner.flush().map_err(std::io::Error::other)
} }
fn transport_shutdown(&mut self) -> Result<(), std::io::Error> { fn transport_shutdown(&mut self) -> Result<(), std::io::Error> {
self.close(None).map_err(std::io::Error::other)?; self.inner.close(None).map_err(std::io::Error::other)?;
loop { loop {
match self.read() { match self.inner.read() {
Ok(_) => continue, Ok(_) => continue,
Err(tungstenite::Error::ConnectionClosed) => break, Err(tungstenite::Error::ConnectionClosed) => break,
Err(e) => return Err(std::io::Error::other(e)), Err(e) => return Err(std::io::Error::other(e)),
@ -62,7 +104,7 @@ impl<T: Read + Write> Transport for WebSocket<T> {
fn transport_read_exact(&mut self, buf: &mut [u8]) -> Result<(), std::io::Error> { fn transport_read_exact(&mut self, buf: &mut [u8]) -> Result<(), std::io::Error> {
let mut total_read = 0; let mut total_read = 0;
while total_read < buf.len() { while total_read < buf.len() {
total_read += self.transport_read(&mut buf[total_read..])?; total_read += self.read_buffered(&mut buf[total_read..])?;
} }
Ok(()) Ok(())
@ -75,6 +117,38 @@ mod tests {
use super::*; use super::*;
#[test]
fn websocket_read_buffered() {
let jh = std::thread::spawn(|| {
let server = TcpListener::bind("127.0.0.1:51232").unwrap();
let stream = server.incoming().next().unwrap().unwrap();
let mut websocket = tungstenite::accept(stream).unwrap();
websocket
.send(tungstenite::Message::binary([1, 2, 3, 4]))
.unwrap();
websocket
.send(tungstenite::Message::binary([5, 6, 7, 8]))
.unwrap();
});
let (websocket, _) = tungstenite::connect("ws://127.0.0.1:51232").unwrap();
let mut websocket = WebSocket::new(websocket);
let mut buf = [0u8; 2];
assert_eq!(websocket.read_buffered(&mut buf).unwrap(), 2);
assert_eq!(buf, [1, 2]);
assert_eq!(websocket.read_buffered(&mut buf).unwrap(), 2);
assert_eq!(buf, [3, 4]);
let mut buf = [0u8; 4];
assert_eq!(websocket.read_buffered(&mut buf).unwrap(), 4);
assert_eq!(buf, [5, 6, 7, 8]);
let _ = websocket.transport_shutdown();
jh.join().unwrap();
}
#[test] #[test]
fn websocket_read_exact() { fn websocket_read_exact() {
let jh = std::thread::spawn(|| { let jh = std::thread::spawn(|| {
@ -86,7 +160,8 @@ mod tests {
.unwrap(); .unwrap();
}); });
let (mut websocket, _) = tungstenite::connect("ws://127.0.0.1:51234").unwrap(); let (websocket, _) = tungstenite::connect("ws://127.0.0.1:51234").unwrap();
let mut websocket = WebSocket::new(websocket);
fn read_exact<T: Transport>(stream: &mut T) { fn read_exact<T: Transport>(stream: &mut T) {
let mut buf = [0u8; 3]; let mut buf = [0u8; 3];
@ -96,7 +171,7 @@ mod tests {
read_exact(&mut websocket); read_exact(&mut websocket);
websocket.close(None).unwrap(); let _ = websocket.transport_shutdown();
jh.join().unwrap(); jh.join().unwrap();
} }