diff --git a/senders/terminal/src/main.rs b/senders/terminal/src/main.rs index db43cf5..59fdbca 100644 --- a/senders/terminal/src/main.rs +++ b/senders/terminal/src/main.rs @@ -1,4 +1,5 @@ use clap::{App, Arg, SubCommand}; +use fcast::transport::WebSocket; use std::collections::HashMap; use std::net::IpAddr; use std::str::FromStr; @@ -204,6 +205,7 @@ fn run() -> Result<(), Box> { MaybeTlsStream::Plain(ref stream) => Some(stream.local_addr()?.ip()), _ => return Err("Established connection type is not plain.".into()), }; + let stream = WebSocket::new(stream); FCastSession::connect(stream)? } _ => return Err("Invalid connection type.".into()), diff --git a/senders/terminal/src/transport.rs b/senders/terminal/src/transport.rs index 8a79c8c..562466a 100644 --- a/senders/terminal/src/transport.rs +++ b/senders/terminal/src/transport.rs @@ -1,6 +1,7 @@ +use std::collections::VecDeque; use std::io::{Read, Write}; use std::net::TcpStream; -use tungstenite::protocol::WebSocket; +use tungstenite::protocol::WebSocket as TWebSocket; use tungstenite::Message; pub trait Transport { @@ -28,28 +29,69 @@ impl Transport for TcpStream { } } -impl Transport for WebSocket { - fn transport_read(&mut self, buf: &mut [u8]) -> Result { - match self.read() { - Ok(Message::Binary(data)) => { - let len = std::cmp::min(buf.len(), data.len()); - buf[..len].copy_from_slice(&data[..len]); - Ok(len) - } - _ => Err(std::io::Error::other("Invalid message type")), +pub struct WebSocket +where + T: Read + Write, +{ + inner: TWebSocket, + buffer: VecDeque, +} + +impl WebSocket +where + T: Read + Write, +{ + pub fn new(web_socket: TWebSocket) -> Self { + Self { + inner: web_socket, + buffer: VecDeque::new(), } } + pub fn read_buffered(&mut self, buf: &mut [u8]) -> Result { + if !self.buffer.is_empty() { + let bytes_to_read = buf.len().min(self.buffer.len()); + assert!(buf.len() >= bytes_to_read); + assert!(self.buffer.len() >= bytes_to_read); + for i in 0..bytes_to_read { + buf[i] = self.buffer.pop_front().unwrap(); // Safe unwrap as bounds was checked previously + } + } else { + match self.inner.read() { + Ok(Message::Binary(data)) => { + let bytes_to_read = buf.len().min(data.len()); + buf.copy_from_slice(&data[..bytes_to_read]); + for rest in data[bytes_to_read..].iter() { + self.buffer.push_back(*rest); + } + } + _ => return Err(std::io::Error::other("Invalid message type")), + } + } + + Ok(buf.len()) + } +} + +impl Transport for WebSocket +where + T: Read + Write, +{ + fn transport_read(&mut self, buf: &mut [u8]) -> Result { + self.read_buffered(buf) + } + fn transport_write(&mut self, buf: &[u8]) -> Result<(), std::io::Error> { - self.write(Message::Binary(buf.to_vec())) + self.inner + .write(Message::Binary(buf.to_vec())) .map_err(std::io::Error::other)?; - self.flush().map_err(std::io::Error::other) + self.inner.flush().map_err(std::io::Error::other) } fn transport_shutdown(&mut self) -> Result<(), std::io::Error> { - self.close(None).map_err(std::io::Error::other)?; + self.inner.close(None).map_err(std::io::Error::other)?; loop { - match self.read() { + match self.inner.read() { Ok(_) => continue, Err(tungstenite::Error::ConnectionClosed) => break, Err(e) => return Err(std::io::Error::other(e)), @@ -62,7 +104,7 @@ impl Transport for WebSocket { fn transport_read_exact(&mut self, buf: &mut [u8]) -> Result<(), std::io::Error> { let mut total_read = 0; while total_read < buf.len() { - total_read += self.transport_read(&mut buf[total_read..])?; + total_read += self.read_buffered(&mut buf[total_read..])?; } Ok(()) @@ -75,6 +117,38 @@ mod tests { use super::*; + #[test] + fn websocket_read_buffered() { + let jh = std::thread::spawn(|| { + let server = TcpListener::bind("127.0.0.1:51232").unwrap(); + let stream = server.incoming().next().unwrap().unwrap(); + let mut websocket = tungstenite::accept(stream).unwrap(); + websocket + .send(tungstenite::Message::binary([1, 2, 3, 4])) + .unwrap(); + websocket + .send(tungstenite::Message::binary([5, 6, 7, 8])) + .unwrap(); + }); + + let (websocket, _) = tungstenite::connect("ws://127.0.0.1:51232").unwrap(); + let mut websocket = WebSocket::new(websocket); + + let mut buf = [0u8; 2]; + assert_eq!(websocket.read_buffered(&mut buf).unwrap(), 2); + assert_eq!(buf, [1, 2]); + assert_eq!(websocket.read_buffered(&mut buf).unwrap(), 2); + assert_eq!(buf, [3, 4]); + + let mut buf = [0u8; 4]; + assert_eq!(websocket.read_buffered(&mut buf).unwrap(), 4); + assert_eq!(buf, [5, 6, 7, 8]); + + let _ = websocket.transport_shutdown(); + + jh.join().unwrap(); + } + #[test] fn websocket_read_exact() { let jh = std::thread::spawn(|| { @@ -86,7 +160,8 @@ mod tests { .unwrap(); }); - let (mut websocket, _) = tungstenite::connect("ws://127.0.0.1:51234").unwrap(); + let (websocket, _) = tungstenite::connect("ws://127.0.0.1:51234").unwrap(); + let mut websocket = WebSocket::new(websocket); fn read_exact(stream: &mut T) { let mut buf = [0u8; 3]; @@ -96,7 +171,7 @@ mod tests { read_exact(&mut websocket); - websocket.close(None).unwrap(); + let _ = websocket.transport_shutdown(); jh.join().unwrap(); }