From d7ae48f1442e5082efb9e8de91d3035323a7f605 Mon Sep 17 00:00:00 2001 From: Malachi Thomas Date: Thu, 4 Jul 2024 21:07:07 -0700 Subject: [PATCH 1/3] Initial server implementation --- .gitignore | 5 + Cargo.toml | 24 ++++ src/bin/chat-clt.rs | 3 + src/bin/chat-svr.rs | 22 ++++ src/lib.rs | 17 +++ src/server.rs | 281 ++++++++++++++++++++++++++++++++++++++++++++ tests/main.rs | 0 7 files changed, 352 insertions(+) create mode 100644 Cargo.toml create mode 100644 src/bin/chat-clt.rs create mode 100644 src/bin/chat-svr.rs create mode 100644 src/lib.rs create mode 100644 src/server.rs create mode 100644 tests/main.rs diff --git a/.gitignore b/.gitignore index 6985cf1..196e176 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,8 @@ Cargo.lock # MSVC Windows builds of rustc generate these, which store debugging information *.pdb + + +# Added by cargo + +/target diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..b7c7bb0 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "simple-chat" +version = "0.1.0" +edition = "2021" + +[[bin]] +name = "chat-svr" + +[[bin]] +name = "chat-clt" + +[dependencies] +anyhow = { version = "1.0.86", default-features = false, features = ["std"] } +clap = { version = "4.5.8", default-features = false, features = ["std", "derive"] } +futures = { version = "0.3.30", default-features = false } +futures-util = { version = "0.3.30", default-features = false, features = ["sink"] } +tokio = { version = "1.38.0", default-features = false, features = ["full"] } +tokio-stream = { version = "0.1.15", default-features = false, features = ["net"] } +tokio-util = { version = "0.7.11", default-features = false, features = ["codec"] } +serde = { version = "1.0.194", default-features = false, features = ["derive"] } +serde_json = { version = "1.0.120", default-features = false, features = ["std"] } + +[workspace.lints.clippy] +wildcard_imports = "deny" diff --git a/src/bin/chat-clt.rs b/src/bin/chat-clt.rs new file mode 100644 index 0000000..e7a11a9 --- /dev/null +++ b/src/bin/chat-clt.rs @@ -0,0 +1,3 @@ +fn main() { + println!("Hello, world!"); +} diff --git a/src/bin/chat-svr.rs b/src/bin/chat-svr.rs new file mode 100644 index 0000000..c0028ad --- /dev/null +++ b/src/bin/chat-svr.rs @@ -0,0 +1,22 @@ +use anyhow::Result; +use clap::Parser; + +#[derive(Parser)] +#[command(version, about, long_about = None)] +struct Cli { + /// Set the server port to listen on. Defaults to `8080`. + #[arg(short)] + port: Option, +} + +#[tokio::main] +async fn main() -> Result<()> { + let cli = Cli::parse(); + let port = if let Some(port) = cli.port { + port + } else { + 8080 + }; + let addr = format!("127.0.0.1:{port}"); + simple_chat::server::run(addr).await +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..c7d16ab --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,17 @@ +use serde::{Deserialize, Serialize}; + +pub mod server; + +#[derive(Serialize, Deserialize)] +pub enum ServerMessage { + Message(String), + Error(String), + Success, +} + +#[derive(Serialize, Deserialize)] +pub enum ClientMessage { + Connect(String), + SendMsg(String), + Leave, +} diff --git a/src/server.rs b/src/server.rs new file mode 100644 index 0000000..27ad797 --- /dev/null +++ b/src/server.rs @@ -0,0 +1,281 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use anyhow::Result; +use futures_util::sink::SinkExt; +use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; +use tokio::sync::mpsc::{self, Receiver, Sender}; +use tokio::sync::Mutex; +use tokio_stream::wrappers::TcpListenerStream; +use tokio_stream::StreamExt; +use tokio_util::codec::{Framed, LinesCodec}; + +use crate::{ClientMessage, ServerMessage}; + +/// Implementation Note: Originally tried to use the tokio::sync::broadcast +/// channel type, but the requirement is not to send the message to the sender. +/// Similarly, trying to store the actual network handle inside this set +/// results in some bad borrowing failures. +struct ClientSet { + inner: Mutex>>, +} + +impl ClientSet { + fn new() -> Self { + Self { + inner: Mutex::new(HashMap::new()), + } + } + + /// Registers the nick if it is not already in use + async fn register(&self, nick: &str, sender: Sender) -> Result<()> { + let mut inner = self.inner.lock().await; + if inner.contains_key(nick) { + return Err(anyhow::Error::msg("nick already registered: {nick}")); + } + + inner.insert(nick.to_string(), sender); + + Ok(()) + } + + /// De-registers the nick, a nick that is not registered is not remov + async fn deregister(&self, nick: &str) { + let mut inner = self.inner.lock().await; + inner.remove(nick); + } +} + +/// Start the server to listen for incoming connections +pub async fn run(addr: A) -> Result<()> { + let listener = TcpListener::bind(addr).await?; + let mut listener = TcpListenerStream::new(listener); + + let nick_set = Arc::new(ClientSet::new()); + + while let Some(socket) = listener.next().await { + let socket = socket?; + let mut socket = Framed::new(socket, LinesCodec::new()); + if let Some(msg) = socket.next().await { + let (tx, rx) = mpsc::channel(16); + if let ClientMessage::Connect(nick) = serde_json::from_str(&msg?)? { + let mut clt = ClientHandle { + nick, + nick_set: nick_set.clone(), + socket, + receiver: rx, + }; + + match nick_set.register(&clt.nick, tx).await { + Ok(()) => { + clt.send(ServerMessage::Success).await?; + let connect = format!("{} has joined the channel", clt.nick); + clt.publish(&connect).await?; + tokio::spawn(async move { clt.handle_client().await }); + } + Err(err) => { + clt.send(ServerMessage::Error(err.to_string())).await?; + continue; + } + } + } + + return Err(anyhow::Error::msg( + "expected nick registration, received: {:?}", + )); + } + } + + Ok(()) +} + +struct ClientHandle { + nick: String, + nick_set: Arc, + socket: Framed, + receiver: Receiver, +} + +impl ClientHandle { + async fn send(&mut self, msg: ServerMessage) -> Result<()> { + let msg_json = serde_json::to_string(&msg)?; + self.socket.send(msg_json).await?; + Ok(()) + } + + /// Loops through the list of other clients to send them the message + async fn publish(&mut self, msg: &str) -> Result<()> { + let mut inner = self.nick_set.inner.lock().await; + for (nick, sender) in inner.iter_mut() { + if *nick == self.nick { + continue; + } + + sender.send(msg.to_string()).await?; + } + Ok(()) + } + + async fn handle_client(mut self) -> Result<()> { + loop { + tokio::select!( + incoming = self.socket.next() => { + if let Some(incoming) = incoming { + match serde_json::from_str(&incoming?)? { + ClientMessage::SendMsg(incoming) => { + self.publish(&incoming).await?; + } + ClientMessage::Leave => { + self.nick_set.deregister(&self.nick).await; + self.send(ServerMessage::Success).await?; + let disconnect = format!("{} has disconnected", self.nick); + self.publish(&disconnect).await?; + break; + } + _ => todo!(), + } + } + }, + msg = self.receiver.recv() => { + if let Some(msg) = msg { + self.send(ServerMessage::Message(msg)).await?; + } + }, + ); + } + Ok(()) + } +} + +#[cfg(test)] +mod test { + //use futures::sink::SinkExt; + //use tokio::net::TcpStream; + //use tokio_stream::StreamExt; + //use tokio_util::codec::{Framed, LinesCodec}; + // + //use simple_chat::ClientMessage; + + #[tokio::test] + #[should_panic] + async fn register_twice() { + use tokio::sync::mpsc; + let (tx, _) = mpsc::channel(16); + let nick_set = std::sync::Arc::new(crate::server::ClientSet::new()); + let nick = "Nick1"; + nick_set.register(nick, tx.clone()).await.unwrap(); + + let nick = "Nick1"; + nick_set.register(nick, tx.clone()).await.unwrap(); + } + + //#[tokio::test] + //async fn client_connect() { + // tokio::spawn(async move { crate::run("127.0.0.1:8080").await }); + // // Just waiting for the server to start, this should be swapped out + // // since it slows down the tests + // tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + // let sock = TcpStream::connect("127.0.0.1:8080").await.unwrap(); + // let mut sock = Framed::new(sock, LinesCodec::new()); + // let register = ClientMessage::Connect("Nick1".to_string()); + // let register_json = serde_json::to_string(®ister).unwrap(); + // sock.send(register_json).await.unwrap(); + // let reply = sock.next().await.unwrap().unwrap(); + // eprintln!("{reply:?}"); + // + // let send_msg = ClientMessage::SendMsg("TestMsg".to_string()); + // let send_msg_json = serde_json::to_string(&send_msg).unwrap(); + // sock.send(send_msg_json).await.unwrap(); + // tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + //} + // + //#[tokio::test] + //async fn register_same_nick() { + // tokio::spawn(async move { crate::run("127.0.0.1:8080").await }); + // // Just waiting for the server to start, this should be swapped out + // // since it slows down the tests + // tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + // let sock = TcpStream::connect("127.0.0.1:8080").await.unwrap(); + // let mut sock = Framed::new(sock, LinesCodec::new()); + // let register = ClientMessage::Connect("Nick1".to_string()); + // let register_json = serde_json::to_string(®ister).unwrap(); + // sock.send(register_json).await.unwrap(); + // let reply = sock.next().await.unwrap().unwrap(); + // eprintln!("{reply:?}"); + // + // let sock = TcpStream::connect("127.0.0.1:8080").await.unwrap(); + // let mut sock = Framed::new(sock, LinesCodec::new()); + // let register = ClientMessage::Connect("Nick1".to_string()); + // let register_json = serde_json::to_string(®ister).unwrap(); + // sock.send(register_json).await.unwrap(); + // let reply = sock.next().await.unwrap().unwrap(); + // eprintln!("{reply:?}"); + // tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + //} + // + //#[tokio::test] + //async fn register_same_nick_deregister() { + // tokio::spawn(async move { crate::run("127.0.0.1:8080").await }); + // // Just waiting for the server to start, this should be swapped out + // // since it slows down the tests + // tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + // let sock = TcpStream::connect("127.0.0.1:8080").await.unwrap(); + // let mut sock = Framed::new(sock, LinesCodec::new()); + // let register = ClientMessage::Connect("Nick1".to_string()); + // let register_json = serde_json::to_string(®ister).unwrap(); + // sock.send(register_json).await.unwrap(); + // let reply = sock.next().await.unwrap().unwrap(); + // eprintln!("{reply:?}"); + // + // let leave = ClientMessage::Leave; + // let leave_json = serde_json::to_string(&leave).unwrap(); + // sock.send(leave_json).await.unwrap(); + // let reply = sock.next().await.unwrap().unwrap(); + // eprintln!("{reply:?}"); + // + // let sock = TcpStream::connect("127.0.0.1:8080").await.unwrap(); + // let mut sock = Framed::new(sock, LinesCodec::new()); + // let register = ClientMessage::Connect("Nick1".to_string()); + // let register_json = serde_json::to_string(®ister).unwrap(); + // sock.send(register_json).await.unwrap(); + // let reply = sock.next().await.unwrap().unwrap(); + // eprintln!("{reply:?}"); + // tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + //} + // + //#[tokio::test] + //async fn send_multiple() { + // tokio::spawn(async move { crate::run("127.0.0.1:8080").await }); + // // Just waiting for the server to start, this should be swapped out + // // since it slows down the tests + // tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + // tokio::spawn(async move { + // let sock = TcpStream::connect("127.0.0.1:8080").await.unwrap(); + // let mut sock = Framed::new(sock, LinesCodec::new()); + // let register = ClientMessage::Connect("Nick1".to_string()); + // let register_json = serde_json::to_string(®ister).unwrap(); + // sock.send(register_json).await.unwrap(); + // let reply = sock.next().await.unwrap().unwrap(); + // eprintln!("Nick1: {reply:?}"); + // + // let send_msg = ClientMessage::SendMsg("TestMsg".to_string()); + // let send_msg_json = serde_json::to_string(&send_msg).unwrap(); + // sock.send(send_msg_json).await.unwrap(); + // }); + // + // tokio::spawn(async move { + // let sock = TcpStream::connect("127.0.0.1:8080").await.unwrap(); + // let mut sock = Framed::new(sock, LinesCodec::new()); + // let register = ClientMessage::Connect("Nick2".to_string()); + // let register_json = serde_json::to_string(®ister).unwrap(); + // sock.send(register_json).await.unwrap(); + // let reply = sock.next().await.unwrap().unwrap(); + // eprintln!("Nick2: {reply:?}"); + // + // let reply = sock.next().await.unwrap().unwrap(); + // eprintln!("Nick2: {reply:?}"); + // }); + // + // tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + //} +} diff --git a/tests/main.rs b/tests/main.rs new file mode 100644 index 0000000..e69de29 From af114fc0d659487c3cf5014551526afddc58e9a9 Mon Sep 17 00:00:00 2001 From: Malachi Thomas Date: Fri, 5 Jul 2024 22:15:17 -0700 Subject: [PATCH 2/3] Update unit tests for server, create UI for client --- Cargo.toml | 3 + src/bin/chat-clt.rs | 227 +++++++++++++++++++++++++++++++++++++++++++- src/client.rs | 19 ++++ src/lib.rs | 1 + src/server.rs | 224 ++++++++++++++----------------------------- tests/main.rs | 1 + 6 files changed, 321 insertions(+), 154 deletions(-) create mode 100644 src/client.rs diff --git a/Cargo.toml b/Cargo.toml index b7c7bb0..05eff78 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,8 +17,11 @@ futures-util = { version = "0.3.30", default-features = false, features = ["sink tokio = { version = "1.38.0", default-features = false, features = ["full"] } tokio-stream = { version = "0.1.15", default-features = false, features = ["net"] } tokio-util = { version = "0.7.11", default-features = false, features = ["codec"] } +ratatui = { version = "0.27.0", default-features = false, features = ["crossterm"] } +regex = { version = "1.10.5", default-features = false, features = ["unicode-perl"] } serde = { version = "1.0.194", default-features = false, features = ["derive"] } serde_json = { version = "1.0.120", default-features = false, features = ["std"] } +tui-input = { version = "0.9.0", default-features = false, features = ["crossterm"] } [workspace.lints.clippy] wildcard_imports = "deny" diff --git a/src/bin/chat-clt.rs b/src/bin/chat-clt.rs index e7a11a9..c8ea4bc 100644 --- a/src/bin/chat-clt.rs +++ b/src/bin/chat-clt.rs @@ -1,3 +1,226 @@ -fn main() { - println!("Hello, world!"); +use std::io::{stdout, Stdout}; +use std::sync::Mutex; + +use anyhow::Result; +use ratatui::{ + backend::CrosstermBackend, + crossterm::{ + event::{self, KeyCode}, + execute, + terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen}, + }, + prelude::{Constraint, Layout}, + style::{Style, Stylize}, + symbols::border, + text::{Line, Span, Text}, + widgets::{Block, List, ListItem, Paragraph}, + Frame, Terminal, +}; +use regex::Regex; +use tui_input::{backend::crossterm::EventHandler, Input}; + +use simple_chat::client::Connection; + +static REGEX: std::sync::OnceLock = std::sync::OnceLock::new(); + +type Tui = Terminal>; + +enum ClientCommand { + Connect(String, A), + Send(String), + Leave, +} + +impl core::str::FromStr for ClientCommand { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + let input_re = REGEX.get_or_init(|| { + Regex::new( + r#"(?x) + ^(leave)\s*$ | + ^(connect)\s+(.*)@(.*)$ | + ^(send)\s+(.*) + "#, + ) + .unwrap() + }); + + let captures = input_re.captures(s).map(|captures| { + captures + .iter() + .skip(1) + .flatten() + .map(|c| c.as_str()) + .collect::>() + }); + + let ret = match captures.as_deref() { + Some(["leave"]) => Ok(Self::Leave), + Some(["connect", nick, addr]) => Ok(Self::Connect(nick.to_string(), addr.to_string())), + Some(["send", message]) => Ok(Self::Send(message.to_string())), + _ => { + anyhow::bail!("invalid command") + } + }; + + ret + } +} + +#[derive(Debug, Default)] +struct Messages { + messages: Mutex>, +} + +impl Messages { + fn push(&mut self, nick: &str, msg: &str) { + self.messages + .lock() + .unwrap() + .push((nick.to_string(), msg.to_string())) + } +} + +#[derive(Debug, Default)] +pub struct App { + input: Input, + messages: Messages, + connection: Option, + exit: bool, +} + +impl App { + fn run(&mut self, terminal: &mut Tui) -> Result<()> { + while !self.exit { + terminal.draw(|frame| ui(frame, self))?; + if event::poll(std::time::Duration::from_millis(16))? { + self.handle_events()?; + } + } + + Ok(()) + } + + fn handle_events(&mut self) -> Result<()> { + if let event::Event::Key(key) = event::read()? { + match key.code { + KeyCode::Enter => { + let input = self.input.value().to_string(); + match input.parse::>() { + Ok(ClientCommand::Leave) => self.exit(), + Ok(ClientCommand::Connect(nick, addr)) => { + if self.connection.is_none() { + self.messages + .push("INFO", &format!("Connecting to {addr} as {nick}")); + if let Ok(connection) = Connection::connect(&nick, addr) { + self.connection = Some(connection); + } else { + // TODO: Display the specific error + self.messages.push("ERROR", "Failed to connect to server"); + } + } else { + self.messages.push("ERROR", "Already connected to server"); + } + } + Ok(ClientCommand::Send(message)) => { + if let Some(connection) = &self.connection { + self.messages.push(&connection.nick, &message); + connection.send(&message).unwrap(); + } else { + self.messages.push("ERROR", "Not connected to a server"); + } + } + _ => { + self.messages.push("ERROR", "Invalid command"); + } + } + self.input.reset(); + } + _ => { + self.input.handle_event(&event::Event::Key(key)); + } + } + } + + Ok(()) + } + + fn exit(&mut self) { + self.exit = true; + } +} + +fn ui(f: &mut Frame, app: &App) { + let vertical = Layout::vertical([ + Constraint::Min(1), + Constraint::Length(3), + Constraint::Length(1), + ]); + let [message_area, input_area, help_area] = vertical.areas(f.size()); + + let header = Text::from(Line::from(vec![ + "Type ".into(), + "connect @:".bold().green(), + " to connect to a server, type ".into(), + "leave".bold().red(), + " to exit.".into(), + ])) + .patch_style(Style::default()); + let help_message = Paragraph::new(header); + f.render_widget(help_message, help_area); + + let width = input_area.width.max(3) - 3; + let scroll = app.input.visual_scroll(width as usize); + let input_block = Block::bordered().title(" Input ").border_set(border::THICK); + let input = Paragraph::new(app.input.value()) + .scroll((0, scroll as u16)) + .block(input_block); + f.render_widget(input, input_area); + + f.set_cursor( + input_area.x + ((app.input.visual_cursor()).max(scroll) - scroll) as u16 + 1, + input_area.y + 1, + ); + + let message_block = Block::bordered() + .title(" Messages ") + .border_set(border::THICK); + let messages: Vec = app + .messages + .messages + .lock() + .unwrap() + .iter() + .map(|(n, m)| { + let content = vec![Line::from(Span::raw(format!("{n}: {m}")))]; + ListItem::new(content) + }) + .collect(); + let messages = List::new(messages).block(message_block); + f.render_widget(messages, message_area); +} + +fn main() -> Result<()> { + let mut terminal = init()?; + terminal.clear()?; + + let mut app = App::default(); + let res = app.run(&mut terminal); + + restore()?; + + res +} + +fn init() -> Result { + execute!(stdout(), EnterAlternateScreen)?; + enable_raw_mode()?; + Terminal::new(CrosstermBackend::new(stdout())).map_err(anyhow::Error::from) +} + +fn restore() -> Result<()> { + execute!(stdout(), LeaveAlternateScreen)?; + disable_raw_mode()?; + Ok(()) } diff --git a/src/client.rs b/src/client.rs new file mode 100644 index 0000000..c3194c8 --- /dev/null +++ b/src/client.rs @@ -0,0 +1,19 @@ +use std::net::ToSocketAddrs; + +use anyhow::Result; + +#[derive(Debug, Default)] +pub struct Connection { + pub nick: String, +} + +impl Connection { + pub fn connect(nick: &str, addr: A) -> Result { + todo!(); + } + + pub fn send(&self, msg: &str) -> Result<()> { + // TODO: Send the message to the server + todo!(); + } +} diff --git a/src/lib.rs b/src/lib.rs index c7d16ab..c126cb0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ use serde::{Deserialize, Serialize}; +pub mod client; pub mod server; #[derive(Serialize, Deserialize)] diff --git a/src/server.rs b/src/server.rs index 27ad797..06d3422 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,17 +1,24 @@ use std::collections::HashMap; +use std::marker::Unpin; use std::sync::Arc; use anyhow::Result; use futures_util::sink::SinkExt; -use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; -use tokio::sync::mpsc::{self, Receiver, Sender}; -use tokio::sync::Mutex; -use tokio_stream::wrappers::TcpListenerStream; -use tokio_stream::StreamExt; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + net::{TcpListener, ToSocketAddrs}, + sync::{ + mpsc::{self, Receiver, Sender}, + Mutex, + }, +}; +use tokio_stream::{wrappers::TcpListenerStream, StreamExt}; use tokio_util::codec::{Framed, LinesCodec}; use crate::{ClientMessage, ServerMessage}; +type ChatFrame = Framed; + /// Implementation Note: Originally tried to use the tokio::sync::broadcast /// channel type, but the requirement is not to send the message to the sender. /// Similarly, trying to store the actual network handle inside this set @@ -54,49 +61,66 @@ pub async fn run(addr: A) -> Result<()> { let nick_set = Arc::new(ClientSet::new()); while let Some(socket) = listener.next().await { - let socket = socket?; - let mut socket = Framed::new(socket, LinesCodec::new()); - if let Some(msg) = socket.next().await { - let (tx, rx) = mpsc::channel(16); - if let ClientMessage::Connect(nick) = serde_json::from_str(&msg?)? { - let mut clt = ClientHandle { - nick, - nick_set: nick_set.clone(), - socket, - receiver: rx, - }; + if let Ok(socket) = socket { + handle_incoming(socket, nick_set.clone()).await?; + } + } - match nick_set.register(&clt.nick, tx).await { - Ok(()) => { - clt.send(ServerMessage::Success).await?; - let connect = format!("{} has joined the channel", clt.nick); - clt.publish(&connect).await?; - tokio::spawn(async move { clt.handle_client().await }); - } - Err(err) => { - clt.send(ServerMessage::Error(err.to_string())).await?; - continue; - } + Ok(()) +} + +async fn handle_incoming(socket: S, nick_set: Arc) -> Result<()> +where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + let mut socket = Framed::new(socket, LinesCodec::new()); + if let Some(msg) = socket.next().await { + if let ClientMessage::Connect(nick) = serde_json::from_str(&msg?)? { + let (mut clt, tx) = ClientHandle::new(&nick, nick_set, socket); + match clt.nick_set.register(&clt.nick, tx).await { + Ok(()) => { + clt.send(ServerMessage::Success).await?; + let connect = format!("{} has joined the channel", clt.nick); + clt.publish(&connect).await?; + tokio::spawn(async move { clt.handle_client().await }); + } + Err(err) => { + clt.send(ServerMessage::Error(err.to_string())).await?; } } - - return Err(anyhow::Error::msg( - "expected nick registration, received: {:?}", - )); } + + return Err(anyhow::Error::msg( + "expected nick registration, received: {:?}", + )); } Ok(()) } -struct ClientHandle { +struct ClientHandle { nick: String, nick_set: Arc, - socket: Framed, + socket: ChatFrame, receiver: Receiver, } -impl ClientHandle { +impl ClientHandle +where + T: AsyncRead + AsyncWrite + Unpin, +{ + fn new(nick: &str, nick_set: Arc, socket: ChatFrame) -> (Self, Sender) { + let (tx, rx) = mpsc::channel::(16); + let clt = ClientHandle { + nick: nick.to_string(), + nick_set, + socket, + receiver: rx, + }; + + (clt, tx) + } + async fn send(&mut self, msg: ServerMessage) -> Result<()> { let msg_json = serde_json::to_string(&msg)?; self.socket.send(msg_json).await?; @@ -149,15 +173,7 @@ impl ClientHandle { #[cfg(test)] mod test { - //use futures::sink::SinkExt; - //use tokio::net::TcpStream; - //use tokio_stream::StreamExt; - //use tokio_util::codec::{Framed, LinesCodec}; - // - //use simple_chat::ClientMessage; - #[tokio::test] - #[should_panic] async fn register_twice() { use tokio::sync::mpsc; let (tx, _) = mpsc::channel(16); @@ -166,116 +182,20 @@ mod test { nick_set.register(nick, tx.clone()).await.unwrap(); let nick = "Nick1"; - nick_set.register(nick, tx.clone()).await.unwrap(); + nick_set.register(nick, tx.clone()).await.unwrap_err(); } - //#[tokio::test] - //async fn client_connect() { - // tokio::spawn(async move { crate::run("127.0.0.1:8080").await }); - // // Just waiting for the server to start, this should be swapped out - // // since it slows down the tests - // tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; - // let sock = TcpStream::connect("127.0.0.1:8080").await.unwrap(); - // let mut sock = Framed::new(sock, LinesCodec::new()); - // let register = ClientMessage::Connect("Nick1".to_string()); - // let register_json = serde_json::to_string(®ister).unwrap(); - // sock.send(register_json).await.unwrap(); - // let reply = sock.next().await.unwrap().unwrap(); - // eprintln!("{reply:?}"); - // - // let send_msg = ClientMessage::SendMsg("TestMsg".to_string()); - // let send_msg_json = serde_json::to_string(&send_msg).unwrap(); - // sock.send(send_msg_json).await.unwrap(); - // tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; - //} - // - //#[tokio::test] - //async fn register_same_nick() { - // tokio::spawn(async move { crate::run("127.0.0.1:8080").await }); - // // Just waiting for the server to start, this should be swapped out - // // since it slows down the tests - // tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; - // let sock = TcpStream::connect("127.0.0.1:8080").await.unwrap(); - // let mut sock = Framed::new(sock, LinesCodec::new()); - // let register = ClientMessage::Connect("Nick1".to_string()); - // let register_json = serde_json::to_string(®ister).unwrap(); - // sock.send(register_json).await.unwrap(); - // let reply = sock.next().await.unwrap().unwrap(); - // eprintln!("{reply:?}"); - // - // let sock = TcpStream::connect("127.0.0.1:8080").await.unwrap(); - // let mut sock = Framed::new(sock, LinesCodec::new()); - // let register = ClientMessage::Connect("Nick1".to_string()); - // let register_json = serde_json::to_string(®ister).unwrap(); - // sock.send(register_json).await.unwrap(); - // let reply = sock.next().await.unwrap().unwrap(); - // eprintln!("{reply:?}"); - // tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; - //} - // - //#[tokio::test] - //async fn register_same_nick_deregister() { - // tokio::spawn(async move { crate::run("127.0.0.1:8080").await }); - // // Just waiting for the server to start, this should be swapped out - // // since it slows down the tests - // tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; - // let sock = TcpStream::connect("127.0.0.1:8080").await.unwrap(); - // let mut sock = Framed::new(sock, LinesCodec::new()); - // let register = ClientMessage::Connect("Nick1".to_string()); - // let register_json = serde_json::to_string(®ister).unwrap(); - // sock.send(register_json).await.unwrap(); - // let reply = sock.next().await.unwrap().unwrap(); - // eprintln!("{reply:?}"); - // - // let leave = ClientMessage::Leave; - // let leave_json = serde_json::to_string(&leave).unwrap(); - // sock.send(leave_json).await.unwrap(); - // let reply = sock.next().await.unwrap().unwrap(); - // eprintln!("{reply:?}"); - // - // let sock = TcpStream::connect("127.0.0.1:8080").await.unwrap(); - // let mut sock = Framed::new(sock, LinesCodec::new()); - // let register = ClientMessage::Connect("Nick1".to_string()); - // let register_json = serde_json::to_string(®ister).unwrap(); - // sock.send(register_json).await.unwrap(); - // let reply = sock.next().await.unwrap().unwrap(); - // eprintln!("{reply:?}"); - // tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; - //} - // - //#[tokio::test] - //async fn send_multiple() { - // tokio::spawn(async move { crate::run("127.0.0.1:8080").await }); - // // Just waiting for the server to start, this should be swapped out - // // since it slows down the tests - // tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; - // tokio::spawn(async move { - // let sock = TcpStream::connect("127.0.0.1:8080").await.unwrap(); - // let mut sock = Framed::new(sock, LinesCodec::new()); - // let register = ClientMessage::Connect("Nick1".to_string()); - // let register_json = serde_json::to_string(®ister).unwrap(); - // sock.send(register_json).await.unwrap(); - // let reply = sock.next().await.unwrap().unwrap(); - // eprintln!("Nick1: {reply:?}"); - // - // let send_msg = ClientMessage::SendMsg("TestMsg".to_string()); - // let send_msg_json = serde_json::to_string(&send_msg).unwrap(); - // sock.send(send_msg_json).await.unwrap(); - // }); - // - // tokio::spawn(async move { - // let sock = TcpStream::connect("127.0.0.1:8080").await.unwrap(); - // let mut sock = Framed::new(sock, LinesCodec::new()); - // let register = ClientMessage::Connect("Nick2".to_string()); - // let register_json = serde_json::to_string(®ister).unwrap(); - // sock.send(register_json).await.unwrap(); - // let reply = sock.next().await.unwrap().unwrap(); - // eprintln!("Nick2: {reply:?}"); - // - // let reply = sock.next().await.unwrap().unwrap(); - // eprintln!("Nick2: {reply:?}"); - // }); - // - // tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; - //} + #[tokio::test] + async fn register_deregister() { + use tokio::sync::mpsc; + let (tx, _) = mpsc::channel(16); + let nick_set = std::sync::Arc::new(crate::server::ClientSet::new()); + let nick = "Nick1"; + nick_set.register(nick, tx.clone()).await.unwrap(); + + let nick = "Nick1"; + nick_set.register(nick, tx.clone()).await.unwrap_err(); + nick_set.deregister(nick).await; + nick_set.register(nick, tx.clone()).await.unwrap(); + } } diff --git a/tests/main.rs b/tests/main.rs index e69de29..8b13789 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -0,0 +1 @@ + From aaf2a4af78cc8feb9682180ad0c282e6d71bf389 Mon Sep 17 00:00:00 2001 From: Malachi Thomas Date: Sat, 6 Jul 2024 16:10:00 -0700 Subject: [PATCH 3/3] Add basic tests, complete implementation --- src/bin/chat-clt.rs | 57 +++++++++---------- src/client.rs | 112 +++++++++++++++++++++++++++++++++++--- src/lib.rs | 13 +++-- src/server.rs | 73 ++++++++++++++----------- tests/integration_test.rs | 32 +++++++++++ tests/main.rs | 1 - 6 files changed, 215 insertions(+), 73 deletions(-) create mode 100644 tests/integration_test.rs delete mode 100644 tests/main.rs diff --git a/src/bin/chat-clt.rs b/src/bin/chat-clt.rs index c8ea4bc..29cf8eb 100644 --- a/src/bin/chat-clt.rs +++ b/src/bin/chat-clt.rs @@ -1,7 +1,7 @@ use std::io::{stdout, Stdout}; -use std::sync::Mutex; +use std::sync::Arc; -use anyhow::Result; +use anyhow::{Context, Result}; use ratatui::{ backend::CrosstermBackend, crossterm::{ @@ -19,7 +19,10 @@ use ratatui::{ use regex::Regex; use tui_input::{backend::crossterm::EventHandler, Input}; -use simple_chat::client::Connection; +use simple_chat::{ + client::{Connection, Messages}, + ClientMessage, +}; static REGEX: std::sync::OnceLock = std::sync::OnceLock::new(); @@ -68,24 +71,10 @@ impl core::str::FromStr for ClientCommand { } } -#[derive(Debug, Default)] -struct Messages { - messages: Mutex>, -} - -impl Messages { - fn push(&mut self, nick: &str, msg: &str) { - self.messages - .lock() - .unwrap() - .push((nick.to_string(), msg.to_string())) - } -} - #[derive(Debug, Default)] pub struct App { input: Input, - messages: Messages, + messages: Arc, connection: Option, exit: bool, } @@ -113,20 +102,24 @@ impl App { if self.connection.is_none() { self.messages .push("INFO", &format!("Connecting to {addr} as {nick}")); - if let Ok(connection) = Connection::connect(&nick, addr) { - self.connection = Some(connection); - } else { - // TODO: Display the specific error - self.messages.push("ERROR", "Failed to connect to server"); + match Connection::connect(&nick, addr, self.messages.clone()) { + Ok(connection) => self.connection = Some(connection), + Err(e) => { + let err = format!("Failed to connect to server: {e}"); + self.messages.push("ERROR", &err); + } } } else { self.messages.push("ERROR", "Already connected to server"); } } Ok(ClientCommand::Send(message)) => { - if let Some(connection) = &self.connection { + if let Some(connection) = &mut self.connection { self.messages.push(&connection.nick, &message); - connection.send(&message).unwrap(); + if let Err(e) = connection.send(ClientMessage::SendMsg(message)) { + let err = format!("Failed to send message: {e}"); + self.messages.push("ERROR", &err); + } } else { self.messages.push("ERROR", "Not connected to a server"); } @@ -148,6 +141,16 @@ impl App { fn exit(&mut self) { self.exit = true; + if let Some(connection) = &mut self.connection { + connection + .send(ClientMessage::Leave) + .or_else(|e| { + eprintln!("Failed to send disconnect message: {e:?}"); + anyhow::Ok(()) + }) + .context("should always return Ok") + .unwrap(); + } } } @@ -188,9 +191,7 @@ fn ui(f: &mut Frame, app: &App) { .border_set(border::THICK); let messages: Vec = app .messages - .messages - .lock() - .unwrap() + .get() .iter() .map(|(n, m)| { let content = vec![Line::from(Span::raw(format!("{n}: {m}")))]; diff --git a/src/client.rs b/src/client.rs index c3194c8..a2ebce6 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,19 +1,117 @@ -use std::net::ToSocketAddrs; +use std::sync::Arc; -use anyhow::Result; +use anyhow::{Context, Result}; +use futures_util::sink::SinkExt; +use tokio::{ + net::{tcp::OwnedWriteHalf, TcpStream, ToSocketAddrs}, + runtime::Runtime, +}; +use tokio_stream::StreamExt; +use tokio_util::codec::{FramedRead, FramedWrite, LinesCodec}; + +use crate::{ClientMessage, ServerMessage}; #[derive(Debug, Default)] +pub struct Messages { + messages: std::sync::Mutex>, +} + +impl Messages { + pub fn push(&self, nick: &str, msg: &str) { + self.messages + .lock() + .unwrap() + .push((nick.to_string(), msg.to_string())) + } + + pub fn get(&self) -> Vec<(String, String)> { + self.messages.lock().unwrap().clone() + } +} + +#[derive(Debug)] pub struct Connection { pub nick: String, + sender: FramedWrite, + rt: Arc, } impl Connection { - pub fn connect(nick: &str, addr: A) -> Result { - todo!(); + pub fn connect(nick: &str, addr: A, messages: Arc) -> Result { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build()?; + + let socket = rt.block_on(async { TcpStream::connect(addr).await })?; + let (receiver, sender) = socket.into_split(); + let mut receiver = FramedRead::new(receiver, LinesCodec::new()); + let sender = FramedWrite::new(sender, LinesCodec::new()); + let mut conn = Connection { + nick: nick.to_string(), + sender, + rt: Arc::new(rt), + }; + + conn.send(ClientMessage::Connect(conn.nick.clone()))?; + let reply = conn + .rt + .block_on(async { receiver.next().await }) + .ok_or(anyhow::Error::msg("error"))?; + match serde_json::from_str(&reply?)? { + ServerMessage::Connected => { + messages.push("INFO", "Connected to server."); + } + ServerMessage::Error(err) => { + return Err(anyhow::Error::msg(format!( + "failed to connect to server: {err}" + ))); + } + _ => { + return Err(anyhow::Error::msg( + "received unexpected message".to_string(), + )); + } + } + + let rt_inner = conn.rt.clone(); + std::thread::spawn(move || { + let listener_task = rt_inner.spawn(async move { + while let Some(incoming) = receiver.next().await { + let message = serde_json::from_str(&incoming?)?; + match message { + ServerMessage::Message(nick, msg) => { + messages.push(&nick, &msg); + } + ServerMessage::Error(err) => { + messages.push("ERROR", &err); + } + ServerMessage::Join(nick) => { + messages.push("SERVER", &format!("{nick} has joined the channel.")); + } + ServerMessage::Leave(nick) => { + messages.push("SERVER", &format!("{nick} has left the channel.")); + } + _ => { + // We don't care about any other message types in here + continue; + } + } + } + messages.push("INFO", "Server has closed the connection"); + anyhow::Ok(()) + }); + + rt_inner.block_on(listener_task)??; + anyhow::Ok(()) + }); + + Ok(conn) } - pub fn send(&self, msg: &str) -> Result<()> { - // TODO: Send the message to the server - todo!(); + pub fn send(&mut self, msg: ClientMessage) -> Result<()> { + let msg_json = serde_json::to_string(&msg)?; + self.rt + .block_on(async { self.sender.send(msg_json).await }) + .context("sending message to server") } } diff --git a/src/lib.rs b/src/lib.rs index c126cb0..af6a783 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,16 +1,21 @@ use serde::{Deserialize, Serialize}; +use tokio_util::codec::{Framed, LinesCodec}; pub mod client; pub mod server; -#[derive(Serialize, Deserialize)] +type ChatFrame = Framed; + +#[derive(Debug, Clone, Serialize, Deserialize)] pub enum ServerMessage { - Message(String), + Message(String, String), Error(String), - Success, + Join(String), + Leave(String), + Connected, } -#[derive(Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] pub enum ClientMessage { Connect(String), SendMsg(String), diff --git a/src/server.rs b/src/server.rs index 06d3422..254c53c 100644 --- a/src/server.rs +++ b/src/server.rs @@ -15,16 +15,14 @@ use tokio::{ use tokio_stream::{wrappers::TcpListenerStream, StreamExt}; use tokio_util::codec::{Framed, LinesCodec}; -use crate::{ClientMessage, ServerMessage}; - -type ChatFrame = Framed; +use crate::{ChatFrame, ClientMessage, ServerMessage}; /// Implementation Note: Originally tried to use the tokio::sync::broadcast /// channel type, but the requirement is not to send the message to the sender. /// Similarly, trying to store the actual network handle inside this set /// results in some bad borrowing failures. struct ClientSet { - inner: Mutex>>, + inner: Mutex>>, } impl ClientSet { @@ -35,10 +33,12 @@ impl ClientSet { } /// Registers the nick if it is not already in use - async fn register(&self, nick: &str, sender: Sender) -> Result<()> { + async fn register(&self, nick: &str, sender: Sender) -> Result<()> { let mut inner = self.inner.lock().await; if inner.contains_key(nick) { - return Err(anyhow::Error::msg("nick already registered: {nick}")); + return Err(anyhow::Error::msg(format!( + "nick already registered: {nick}" + ))); } inner.insert(nick.to_string(), sender); @@ -75,24 +75,27 @@ where { let mut socket = Framed::new(socket, LinesCodec::new()); if let Some(msg) = socket.next().await { - if let ClientMessage::Connect(nick) = serde_json::from_str(&msg?)? { - let (mut clt, tx) = ClientHandle::new(&nick, nick_set, socket); - match clt.nick_set.register(&clt.nick, tx).await { - Ok(()) => { - clt.send(ServerMessage::Success).await?; - let connect = format!("{} has joined the channel", clt.nick); - clt.publish(&connect).await?; - tokio::spawn(async move { clt.handle_client().await }); - } - Err(err) => { - clt.send(ServerMessage::Error(err.to_string())).await?; + match serde_json::from_str(&msg?)? { + ClientMessage::Connect(nick) => { + let (mut clt, tx) = ClientHandle::new(&nick, nick_set, socket); + match clt.nick_set.register(&clt.nick, tx).await { + Ok(()) => { + clt.send(ServerMessage::Connected).await?; + clt.publish(ServerMessage::Join(clt.nick.clone())).await?; + eprintln!("User has joined the channel: {}", clt.nick); + tokio::spawn(async move { clt.handle_client().await }); + } + Err(err) => { + clt.send(ServerMessage::Error(err.to_string())).await?; + } } } + msg => { + return Err(anyhow::Error::msg(format!( + "expected nick registration, received: {msg:?}" + ))); + } } - - return Err(anyhow::Error::msg( - "expected nick registration, received: {:?}", - )); } Ok(()) @@ -102,15 +105,19 @@ struct ClientHandle { nick: String, nick_set: Arc, socket: ChatFrame, - receiver: Receiver, + receiver: Receiver, } impl ClientHandle where T: AsyncRead + AsyncWrite + Unpin, { - fn new(nick: &str, nick_set: Arc, socket: ChatFrame) -> (Self, Sender) { - let (tx, rx) = mpsc::channel::(16); + fn new( + nick: &str, + nick_set: Arc, + socket: ChatFrame, + ) -> (Self, Sender) { + let (tx, rx) = mpsc::channel::(16); let clt = ClientHandle { nick: nick.to_string(), nick_set, @@ -128,14 +135,13 @@ where } /// Loops through the list of other clients to send them the message - async fn publish(&mut self, msg: &str) -> Result<()> { + async fn publish(&mut self, msg: ServerMessage) -> Result<()> { let mut inner = self.nick_set.inner.lock().await; for (nick, sender) in inner.iter_mut() { if *nick == self.nick { continue; } - - sender.send(msg.to_string()).await?; + sender.send(msg.clone()).await?; } Ok(()) } @@ -147,22 +153,23 @@ where if let Some(incoming) = incoming { match serde_json::from_str(&incoming?)? { ClientMessage::SendMsg(incoming) => { - self.publish(&incoming).await?; + eprintln!("Message received - {}: {}", self.nick, incoming); + let msg = ServerMessage::Message(self.nick.clone(), incoming); + self.publish(msg).await?; } ClientMessage::Leave => { + eprintln!("User has left the channel: {}", self.nick); + self.publish(ServerMessage::Leave(self.nick.clone())).await?; self.nick_set.deregister(&self.nick).await; - self.send(ServerMessage::Success).await?; - let disconnect = format!("{} has disconnected", self.nick); - self.publish(&disconnect).await?; break; } - _ => todo!(), + _ => self.send(ServerMessage::Error("Unrecognized message".to_string())).await?, } } }, msg = self.receiver.recv() => { if let Some(msg) = msg { - self.send(ServerMessage::Message(msg)).await?; + self.send(msg).await?; } }, ); diff --git a/tests/integration_test.rs b/tests/integration_test.rs new file mode 100644 index 0000000..4d326e6 --- /dev/null +++ b/tests/integration_test.rs @@ -0,0 +1,32 @@ +#[tokio::test] +async fn connect_send() { + let addr = "127.0.0.1:8080"; + + // Start the server + tokio::spawn(async move { simple_chat::server::run(addr).await }); + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + + let messages = std::sync::Arc::new(simple_chat::client::Messages::default()); + let inner_messages = messages.clone(); + + // Using a separate thread to prevent runtime-in-runtime issues, connect to the server + std::thread::spawn(move || { + simple_chat::client::Connection::connect("test", addr, inner_messages).unwrap(); + }); + + let messages2 = std::sync::Arc::new(simple_chat::client::Messages::default()); + let inner_messages2 = messages2.clone(); + std::thread::spawn(move || { + let mut conn = + simple_chat::client::Connection::connect("test2", addr, inner_messages2).unwrap(); + conn.send(simple_chat::ClientMessage::SendMsg( + "Test message".to_string(), + )) + .unwrap(); + conn.send(simple_chat::ClientMessage::Leave).unwrap(); + }); + + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + assert_eq!(dbg!(messages2.get()).len(), 2); + assert_eq!(dbg!(messages.get()).len(), 4); +} diff --git a/tests/main.rs b/tests/main.rs deleted file mode 100644 index 8b13789..0000000 --- a/tests/main.rs +++ /dev/null @@ -1 +0,0 @@ -