mirror of
https://gitlab.com/futo-org/fcast.git
synced 2025-06-24 21:25:23 +00:00
rs-terminal: buffer data read from websocket when underflowing
This commit is contained in:
parent
6529d91eb2
commit
3a7b7675ba
2 changed files with 94 additions and 17 deletions
|
@ -1,4 +1,5 @@
|
||||||
use clap::{App, Arg, SubCommand};
|
use clap::{App, Arg, SubCommand};
|
||||||
|
use fcast::transport::WebSocket;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::net::IpAddr;
|
use std::net::IpAddr;
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
|
@ -204,6 +205,7 @@ fn run() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
MaybeTlsStream::Plain(ref stream) => Some(stream.local_addr()?.ip()),
|
MaybeTlsStream::Plain(ref stream) => Some(stream.local_addr()?.ip()),
|
||||||
_ => return Err("Established connection type is not plain.".into()),
|
_ => return Err("Established connection type is not plain.".into()),
|
||||||
};
|
};
|
||||||
|
let stream = WebSocket::new(stream);
|
||||||
FCastSession::connect(stream)?
|
FCastSession::connect(stream)?
|
||||||
}
|
}
|
||||||
_ => return Err("Invalid connection type.".into()),
|
_ => return Err("Invalid connection type.".into()),
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
|
use std::collections::VecDeque;
|
||||||
use std::io::{Read, Write};
|
use std::io::{Read, Write};
|
||||||
use std::net::TcpStream;
|
use std::net::TcpStream;
|
||||||
use tungstenite::protocol::WebSocket;
|
use tungstenite::protocol::WebSocket as TWebSocket;
|
||||||
use tungstenite::Message;
|
use tungstenite::Message;
|
||||||
|
|
||||||
pub trait Transport {
|
pub trait Transport {
|
||||||
|
@ -28,28 +29,69 @@ impl Transport for TcpStream {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Read + Write> Transport for WebSocket<T> {
|
pub struct WebSocket<T>
|
||||||
fn transport_read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
|
where
|
||||||
match self.read() {
|
T: Read + Write,
|
||||||
Ok(Message::Binary(data)) => {
|
{
|
||||||
let len = std::cmp::min(buf.len(), data.len());
|
inner: TWebSocket<T>,
|
||||||
buf[..len].copy_from_slice(&data[..len]);
|
buffer: VecDeque<u8>,
|
||||||
Ok(len)
|
}
|
||||||
}
|
|
||||||
_ => Err(std::io::Error::other("Invalid message type")),
|
impl<T> WebSocket<T>
|
||||||
|
where
|
||||||
|
T: Read + Write,
|
||||||
|
{
|
||||||
|
pub fn new(web_socket: TWebSocket<T>) -> Self {
|
||||||
|
Self {
|
||||||
|
inner: web_socket,
|
||||||
|
buffer: VecDeque::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn read_buffered(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
|
||||||
|
if !self.buffer.is_empty() {
|
||||||
|
let bytes_to_read = buf.len().min(self.buffer.len());
|
||||||
|
assert!(buf.len() >= bytes_to_read);
|
||||||
|
assert!(self.buffer.len() >= bytes_to_read);
|
||||||
|
for i in 0..bytes_to_read {
|
||||||
|
buf[i] = self.buffer.pop_front().unwrap(); // Safe unwrap as bounds was checked previously
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
match self.inner.read() {
|
||||||
|
Ok(Message::Binary(data)) => {
|
||||||
|
let bytes_to_read = buf.len().min(data.len());
|
||||||
|
buf.copy_from_slice(&data[..bytes_to_read]);
|
||||||
|
for rest in data[bytes_to_read..].iter() {
|
||||||
|
self.buffer.push_back(*rest);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => return Err(std::io::Error::other("Invalid message type")),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(buf.len())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> Transport for WebSocket<T>
|
||||||
|
where
|
||||||
|
T: Read + Write,
|
||||||
|
{
|
||||||
|
fn transport_read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
|
||||||
|
self.read_buffered(buf)
|
||||||
|
}
|
||||||
|
|
||||||
fn transport_write(&mut self, buf: &[u8]) -> Result<(), std::io::Error> {
|
fn transport_write(&mut self, buf: &[u8]) -> Result<(), std::io::Error> {
|
||||||
self.write(Message::Binary(buf.to_vec()))
|
self.inner
|
||||||
|
.write(Message::Binary(buf.to_vec()))
|
||||||
.map_err(std::io::Error::other)?;
|
.map_err(std::io::Error::other)?;
|
||||||
self.flush().map_err(std::io::Error::other)
|
self.inner.flush().map_err(std::io::Error::other)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn transport_shutdown(&mut self) -> Result<(), std::io::Error> {
|
fn transport_shutdown(&mut self) -> Result<(), std::io::Error> {
|
||||||
self.close(None).map_err(std::io::Error::other)?;
|
self.inner.close(None).map_err(std::io::Error::other)?;
|
||||||
loop {
|
loop {
|
||||||
match self.read() {
|
match self.inner.read() {
|
||||||
Ok(_) => continue,
|
Ok(_) => continue,
|
||||||
Err(tungstenite::Error::ConnectionClosed) => break,
|
Err(tungstenite::Error::ConnectionClosed) => break,
|
||||||
Err(e) => return Err(std::io::Error::other(e)),
|
Err(e) => return Err(std::io::Error::other(e)),
|
||||||
|
@ -62,7 +104,7 @@ impl<T: Read + Write> Transport for WebSocket<T> {
|
||||||
fn transport_read_exact(&mut self, buf: &mut [u8]) -> Result<(), std::io::Error> {
|
fn transport_read_exact(&mut self, buf: &mut [u8]) -> Result<(), std::io::Error> {
|
||||||
let mut total_read = 0;
|
let mut total_read = 0;
|
||||||
while total_read < buf.len() {
|
while total_read < buf.len() {
|
||||||
total_read += self.transport_read(&mut buf[total_read..])?;
|
total_read += self.read_buffered(&mut buf[total_read..])?;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -75,6 +117,38 @@ mod tests {
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn websocket_read_buffered() {
|
||||||
|
let jh = std::thread::spawn(|| {
|
||||||
|
let server = TcpListener::bind("127.0.0.1:51232").unwrap();
|
||||||
|
let stream = server.incoming().next().unwrap().unwrap();
|
||||||
|
let mut websocket = tungstenite::accept(stream).unwrap();
|
||||||
|
websocket
|
||||||
|
.send(tungstenite::Message::binary([1, 2, 3, 4]))
|
||||||
|
.unwrap();
|
||||||
|
websocket
|
||||||
|
.send(tungstenite::Message::binary([5, 6, 7, 8]))
|
||||||
|
.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
let (websocket, _) = tungstenite::connect("ws://127.0.0.1:51232").unwrap();
|
||||||
|
let mut websocket = WebSocket::new(websocket);
|
||||||
|
|
||||||
|
let mut buf = [0u8; 2];
|
||||||
|
assert_eq!(websocket.read_buffered(&mut buf).unwrap(), 2);
|
||||||
|
assert_eq!(buf, [1, 2]);
|
||||||
|
assert_eq!(websocket.read_buffered(&mut buf).unwrap(), 2);
|
||||||
|
assert_eq!(buf, [3, 4]);
|
||||||
|
|
||||||
|
let mut buf = [0u8; 4];
|
||||||
|
assert_eq!(websocket.read_buffered(&mut buf).unwrap(), 4);
|
||||||
|
assert_eq!(buf, [5, 6, 7, 8]);
|
||||||
|
|
||||||
|
let _ = websocket.transport_shutdown();
|
||||||
|
|
||||||
|
jh.join().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn websocket_read_exact() {
|
fn websocket_read_exact() {
|
||||||
let jh = std::thread::spawn(|| {
|
let jh = std::thread::spawn(|| {
|
||||||
|
@ -86,7 +160,8 @@ mod tests {
|
||||||
.unwrap();
|
.unwrap();
|
||||||
});
|
});
|
||||||
|
|
||||||
let (mut websocket, _) = tungstenite::connect("ws://127.0.0.1:51234").unwrap();
|
let (websocket, _) = tungstenite::connect("ws://127.0.0.1:51234").unwrap();
|
||||||
|
let mut websocket = WebSocket::new(websocket);
|
||||||
|
|
||||||
fn read_exact<T: Transport>(stream: &mut T) {
|
fn read_exact<T: Transport>(stream: &mut T) {
|
||||||
let mut buf = [0u8; 3];
|
let mut buf = [0u8; 3];
|
||||||
|
@ -96,7 +171,7 @@ mod tests {
|
||||||
|
|
||||||
read_exact(&mut websocket);
|
read_exact(&mut websocket);
|
||||||
|
|
||||||
websocket.close(None).unwrap();
|
let _ = websocket.transport_shutdown();
|
||||||
|
|
||||||
jh.join().unwrap();
|
jh.join().unwrap();
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue