diff --git a/.gitignore b/.gitignore index 6985cf1..196e176 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,8 @@ Cargo.lock # MSVC Windows builds of rustc generate these, which store debugging information *.pdb + + +# Added by cargo + +/target diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..05eff78 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "simple-chat" +version = "0.1.0" +edition = "2021" + +[[bin]] +name = "chat-svr" + +[[bin]] +name = "chat-clt" + +[dependencies] +anyhow = { version = "1.0.86", default-features = false, features = ["std"] } +clap = { version = "4.5.8", default-features = false, features = ["std", "derive"] } +futures = { version = "0.3.30", default-features = false } +futures-util = { version = "0.3.30", default-features = false, features = ["sink"] } +tokio = { version = "1.38.0", default-features = false, features = ["full"] } +tokio-stream = { version = "0.1.15", default-features = false, features = ["net"] } +tokio-util = { version = "0.7.11", default-features = false, features = ["codec"] } +ratatui = { version = "0.27.0", default-features = false, features = ["crossterm"] } +regex = { version = "1.10.5", default-features = false, features = ["unicode-perl"] } +serde = { version = "1.0.194", default-features = false, features = ["derive"] } +serde_json = { version = "1.0.120", default-features = false, features = ["std"] } +tui-input = { version = "0.9.0", default-features = false, features = ["crossterm"] } + +[workspace.lints.clippy] +wildcard_imports = "deny" diff --git a/src/bin/chat-clt.rs b/src/bin/chat-clt.rs new file mode 100644 index 0000000..29cf8eb --- /dev/null +++ b/src/bin/chat-clt.rs @@ -0,0 +1,227 @@ +use std::io::{stdout, Stdout}; +use std::sync::Arc; + +use anyhow::{Context, Result}; +use ratatui::{ + backend::CrosstermBackend, + crossterm::{ + event::{self, KeyCode}, + execute, + terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen}, + }, + prelude::{Constraint, Layout}, + style::{Style, Stylize}, + symbols::border, + text::{Line, Span, Text}, + widgets::{Block, List, ListItem, Paragraph}, + Frame, Terminal, +}; +use regex::Regex; +use tui_input::{backend::crossterm::EventHandler, Input}; + +use simple_chat::{ + client::{Connection, Messages}, + ClientMessage, +}; + +static REGEX: std::sync::OnceLock = std::sync::OnceLock::new(); + +type Tui = Terminal>; + +enum ClientCommand { + Connect(String, A), + Send(String), + Leave, +} + +impl core::str::FromStr for ClientCommand { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + let input_re = REGEX.get_or_init(|| { + Regex::new( + r#"(?x) + ^(leave)\s*$ | + ^(connect)\s+(.*)@(.*)$ | + ^(send)\s+(.*) + "#, + ) + .unwrap() + }); + + let captures = input_re.captures(s).map(|captures| { + captures + .iter() + .skip(1) + .flatten() + .map(|c| c.as_str()) + .collect::>() + }); + + let ret = match captures.as_deref() { + Some(["leave"]) => Ok(Self::Leave), + Some(["connect", nick, addr]) => Ok(Self::Connect(nick.to_string(), addr.to_string())), + Some(["send", message]) => Ok(Self::Send(message.to_string())), + _ => { + anyhow::bail!("invalid command") + } + }; + + ret + } +} + +#[derive(Debug, Default)] +pub struct App { + input: Input, + messages: Arc, + connection: Option, + exit: bool, +} + +impl App { + fn run(&mut self, terminal: &mut Tui) -> Result<()> { + while !self.exit { + terminal.draw(|frame| ui(frame, self))?; + if event::poll(std::time::Duration::from_millis(16))? { + self.handle_events()?; + } + } + + Ok(()) + } + + fn handle_events(&mut self) -> Result<()> { + if let event::Event::Key(key) = event::read()? { + match key.code { + KeyCode::Enter => { + let input = self.input.value().to_string(); + match input.parse::>() { + Ok(ClientCommand::Leave) => self.exit(), + Ok(ClientCommand::Connect(nick, addr)) => { + if self.connection.is_none() { + self.messages + .push("INFO", &format!("Connecting to {addr} as {nick}")); + match Connection::connect(&nick, addr, self.messages.clone()) { + Ok(connection) => self.connection = Some(connection), + Err(e) => { + let err = format!("Failed to connect to server: {e}"); + self.messages.push("ERROR", &err); + } + } + } else { + self.messages.push("ERROR", "Already connected to server"); + } + } + Ok(ClientCommand::Send(message)) => { + if let Some(connection) = &mut self.connection { + self.messages.push(&connection.nick, &message); + if let Err(e) = connection.send(ClientMessage::SendMsg(message)) { + let err = format!("Failed to send message: {e}"); + self.messages.push("ERROR", &err); + } + } else { + self.messages.push("ERROR", "Not connected to a server"); + } + } + _ => { + self.messages.push("ERROR", "Invalid command"); + } + } + self.input.reset(); + } + _ => { + self.input.handle_event(&event::Event::Key(key)); + } + } + } + + Ok(()) + } + + fn exit(&mut self) { + self.exit = true; + if let Some(connection) = &mut self.connection { + connection + .send(ClientMessage::Leave) + .or_else(|e| { + eprintln!("Failed to send disconnect message: {e:?}"); + anyhow::Ok(()) + }) + .context("should always return Ok") + .unwrap(); + } + } +} + +fn ui(f: &mut Frame, app: &App) { + let vertical = Layout::vertical([ + Constraint::Min(1), + Constraint::Length(3), + Constraint::Length(1), + ]); + let [message_area, input_area, help_area] = vertical.areas(f.size()); + + let header = Text::from(Line::from(vec![ + "Type ".into(), + "connect @:".bold().green(), + " to connect to a server, type ".into(), + "leave".bold().red(), + " to exit.".into(), + ])) + .patch_style(Style::default()); + let help_message = Paragraph::new(header); + f.render_widget(help_message, help_area); + + let width = input_area.width.max(3) - 3; + let scroll = app.input.visual_scroll(width as usize); + let input_block = Block::bordered().title(" Input ").border_set(border::THICK); + let input = Paragraph::new(app.input.value()) + .scroll((0, scroll as u16)) + .block(input_block); + f.render_widget(input, input_area); + + f.set_cursor( + input_area.x + ((app.input.visual_cursor()).max(scroll) - scroll) as u16 + 1, + input_area.y + 1, + ); + + let message_block = Block::bordered() + .title(" Messages ") + .border_set(border::THICK); + let messages: Vec = app + .messages + .get() + .iter() + .map(|(n, m)| { + let content = vec![Line::from(Span::raw(format!("{n}: {m}")))]; + ListItem::new(content) + }) + .collect(); + let messages = List::new(messages).block(message_block); + f.render_widget(messages, message_area); +} + +fn main() -> Result<()> { + let mut terminal = init()?; + terminal.clear()?; + + let mut app = App::default(); + let res = app.run(&mut terminal); + + restore()?; + + res +} + +fn init() -> Result { + execute!(stdout(), EnterAlternateScreen)?; + enable_raw_mode()?; + Terminal::new(CrosstermBackend::new(stdout())).map_err(anyhow::Error::from) +} + +fn restore() -> Result<()> { + execute!(stdout(), LeaveAlternateScreen)?; + disable_raw_mode()?; + Ok(()) +} diff --git a/src/bin/chat-svr.rs b/src/bin/chat-svr.rs new file mode 100644 index 0000000..c0028ad --- /dev/null +++ b/src/bin/chat-svr.rs @@ -0,0 +1,22 @@ +use anyhow::Result; +use clap::Parser; + +#[derive(Parser)] +#[command(version, about, long_about = None)] +struct Cli { + /// Set the server port to listen on. Defaults to `8080`. + #[arg(short)] + port: Option, +} + +#[tokio::main] +async fn main() -> Result<()> { + let cli = Cli::parse(); + let port = if let Some(port) = cli.port { + port + } else { + 8080 + }; + let addr = format!("127.0.0.1:{port}"); + simple_chat::server::run(addr).await +} diff --git a/src/client.rs b/src/client.rs new file mode 100644 index 0000000..a2ebce6 --- /dev/null +++ b/src/client.rs @@ -0,0 +1,117 @@ +use std::sync::Arc; + +use anyhow::{Context, Result}; +use futures_util::sink::SinkExt; +use tokio::{ + net::{tcp::OwnedWriteHalf, TcpStream, ToSocketAddrs}, + runtime::Runtime, +}; +use tokio_stream::StreamExt; +use tokio_util::codec::{FramedRead, FramedWrite, LinesCodec}; + +use crate::{ClientMessage, ServerMessage}; + +#[derive(Debug, Default)] +pub struct Messages { + messages: std::sync::Mutex>, +} + +impl Messages { + pub fn push(&self, nick: &str, msg: &str) { + self.messages + .lock() + .unwrap() + .push((nick.to_string(), msg.to_string())) + } + + pub fn get(&self) -> Vec<(String, String)> { + self.messages.lock().unwrap().clone() + } +} + +#[derive(Debug)] +pub struct Connection { + pub nick: String, + sender: FramedWrite, + rt: Arc, +} + +impl Connection { + pub fn connect(nick: &str, addr: A, messages: Arc) -> Result { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build()?; + + let socket = rt.block_on(async { TcpStream::connect(addr).await })?; + let (receiver, sender) = socket.into_split(); + let mut receiver = FramedRead::new(receiver, LinesCodec::new()); + let sender = FramedWrite::new(sender, LinesCodec::new()); + let mut conn = Connection { + nick: nick.to_string(), + sender, + rt: Arc::new(rt), + }; + + conn.send(ClientMessage::Connect(conn.nick.clone()))?; + let reply = conn + .rt + .block_on(async { receiver.next().await }) + .ok_or(anyhow::Error::msg("error"))?; + match serde_json::from_str(&reply?)? { + ServerMessage::Connected => { + messages.push("INFO", "Connected to server."); + } + ServerMessage::Error(err) => { + return Err(anyhow::Error::msg(format!( + "failed to connect to server: {err}" + ))); + } + _ => { + return Err(anyhow::Error::msg( + "received unexpected message".to_string(), + )); + } + } + + let rt_inner = conn.rt.clone(); + std::thread::spawn(move || { + let listener_task = rt_inner.spawn(async move { + while let Some(incoming) = receiver.next().await { + let message = serde_json::from_str(&incoming?)?; + match message { + ServerMessage::Message(nick, msg) => { + messages.push(&nick, &msg); + } + ServerMessage::Error(err) => { + messages.push("ERROR", &err); + } + ServerMessage::Join(nick) => { + messages.push("SERVER", &format!("{nick} has joined the channel.")); + } + ServerMessage::Leave(nick) => { + messages.push("SERVER", &format!("{nick} has left the channel.")); + } + _ => { + // We don't care about any other message types in here + continue; + } + } + } + messages.push("INFO", "Server has closed the connection"); + anyhow::Ok(()) + }); + + rt_inner.block_on(listener_task)??; + anyhow::Ok(()) + }); + + Ok(conn) + } + + pub fn send(&mut self, msg: ClientMessage) -> Result<()> { + let msg_json = serde_json::to_string(&msg)?; + self.rt + .block_on(async { self.sender.send(msg_json).await }) + .context("sending message to server") + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..af6a783 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,23 @@ +use serde::{Deserialize, Serialize}; +use tokio_util::codec::{Framed, LinesCodec}; + +pub mod client; +pub mod server; + +type ChatFrame = Framed; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum ServerMessage { + Message(String, String), + Error(String), + Join(String), + Leave(String), + Connected, +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum ClientMessage { + Connect(String), + SendMsg(String), + Leave, +} diff --git a/src/server.rs b/src/server.rs new file mode 100644 index 0000000..254c53c --- /dev/null +++ b/src/server.rs @@ -0,0 +1,208 @@ +use std::collections::HashMap; +use std::marker::Unpin; +use std::sync::Arc; + +use anyhow::Result; +use futures_util::sink::SinkExt; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + net::{TcpListener, ToSocketAddrs}, + sync::{ + mpsc::{self, Receiver, Sender}, + Mutex, + }, +}; +use tokio_stream::{wrappers::TcpListenerStream, StreamExt}; +use tokio_util::codec::{Framed, LinesCodec}; + +use crate::{ChatFrame, ClientMessage, ServerMessage}; + +/// Implementation Note: Originally tried to use the tokio::sync::broadcast +/// channel type, but the requirement is not to send the message to the sender. +/// Similarly, trying to store the actual network handle inside this set +/// results in some bad borrowing failures. +struct ClientSet { + inner: Mutex>>, +} + +impl ClientSet { + fn new() -> Self { + Self { + inner: Mutex::new(HashMap::new()), + } + } + + /// Registers the nick if it is not already in use + async fn register(&self, nick: &str, sender: Sender) -> Result<()> { + let mut inner = self.inner.lock().await; + if inner.contains_key(nick) { + return Err(anyhow::Error::msg(format!( + "nick already registered: {nick}" + ))); + } + + inner.insert(nick.to_string(), sender); + + Ok(()) + } + + /// De-registers the nick, a nick that is not registered is not remov + async fn deregister(&self, nick: &str) { + let mut inner = self.inner.lock().await; + inner.remove(nick); + } +} + +/// Start the server to listen for incoming connections +pub async fn run(addr: A) -> Result<()> { + let listener = TcpListener::bind(addr).await?; + let mut listener = TcpListenerStream::new(listener); + + let nick_set = Arc::new(ClientSet::new()); + + while let Some(socket) = listener.next().await { + if let Ok(socket) = socket { + handle_incoming(socket, nick_set.clone()).await?; + } + } + + Ok(()) +} + +async fn handle_incoming(socket: S, nick_set: Arc) -> Result<()> +where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + let mut socket = Framed::new(socket, LinesCodec::new()); + if let Some(msg) = socket.next().await { + match serde_json::from_str(&msg?)? { + ClientMessage::Connect(nick) => { + let (mut clt, tx) = ClientHandle::new(&nick, nick_set, socket); + match clt.nick_set.register(&clt.nick, tx).await { + Ok(()) => { + clt.send(ServerMessage::Connected).await?; + clt.publish(ServerMessage::Join(clt.nick.clone())).await?; + eprintln!("User has joined the channel: {}", clt.nick); + tokio::spawn(async move { clt.handle_client().await }); + } + Err(err) => { + clt.send(ServerMessage::Error(err.to_string())).await?; + } + } + } + msg => { + return Err(anyhow::Error::msg(format!( + "expected nick registration, received: {msg:?}" + ))); + } + } + } + + Ok(()) +} + +struct ClientHandle { + nick: String, + nick_set: Arc, + socket: ChatFrame, + receiver: Receiver, +} + +impl ClientHandle +where + T: AsyncRead + AsyncWrite + Unpin, +{ + fn new( + nick: &str, + nick_set: Arc, + socket: ChatFrame, + ) -> (Self, Sender) { + let (tx, rx) = mpsc::channel::(16); + let clt = ClientHandle { + nick: nick.to_string(), + nick_set, + socket, + receiver: rx, + }; + + (clt, tx) + } + + async fn send(&mut self, msg: ServerMessage) -> Result<()> { + let msg_json = serde_json::to_string(&msg)?; + self.socket.send(msg_json).await?; + Ok(()) + } + + /// Loops through the list of other clients to send them the message + async fn publish(&mut self, msg: ServerMessage) -> Result<()> { + let mut inner = self.nick_set.inner.lock().await; + for (nick, sender) in inner.iter_mut() { + if *nick == self.nick { + continue; + } + sender.send(msg.clone()).await?; + } + Ok(()) + } + + async fn handle_client(mut self) -> Result<()> { + loop { + tokio::select!( + incoming = self.socket.next() => { + if let Some(incoming) = incoming { + match serde_json::from_str(&incoming?)? { + ClientMessage::SendMsg(incoming) => { + eprintln!("Message received - {}: {}", self.nick, incoming); + let msg = ServerMessage::Message(self.nick.clone(), incoming); + self.publish(msg).await?; + } + ClientMessage::Leave => { + eprintln!("User has left the channel: {}", self.nick); + self.publish(ServerMessage::Leave(self.nick.clone())).await?; + self.nick_set.deregister(&self.nick).await; + break; + } + _ => self.send(ServerMessage::Error("Unrecognized message".to_string())).await?, + } + } + }, + msg = self.receiver.recv() => { + if let Some(msg) = msg { + self.send(msg).await?; + } + }, + ); + } + Ok(()) + } +} + +#[cfg(test)] +mod test { + #[tokio::test] + async fn register_twice() { + use tokio::sync::mpsc; + let (tx, _) = mpsc::channel(16); + let nick_set = std::sync::Arc::new(crate::server::ClientSet::new()); + let nick = "Nick1"; + nick_set.register(nick, tx.clone()).await.unwrap(); + + let nick = "Nick1"; + nick_set.register(nick, tx.clone()).await.unwrap_err(); + } + + #[tokio::test] + async fn register_deregister() { + use tokio::sync::mpsc; + let (tx, _) = mpsc::channel(16); + let nick_set = std::sync::Arc::new(crate::server::ClientSet::new()); + let nick = "Nick1"; + nick_set.register(nick, tx.clone()).await.unwrap(); + + let nick = "Nick1"; + nick_set.register(nick, tx.clone()).await.unwrap_err(); + nick_set.deregister(nick).await; + nick_set.register(nick, tx.clone()).await.unwrap(); + } +} diff --git a/tests/integration_test.rs b/tests/integration_test.rs new file mode 100644 index 0000000..4d326e6 --- /dev/null +++ b/tests/integration_test.rs @@ -0,0 +1,32 @@ +#[tokio::test] +async fn connect_send() { + let addr = "127.0.0.1:8080"; + + // Start the server + tokio::spawn(async move { simple_chat::server::run(addr).await }); + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + + let messages = std::sync::Arc::new(simple_chat::client::Messages::default()); + let inner_messages = messages.clone(); + + // Using a separate thread to prevent runtime-in-runtime issues, connect to the server + std::thread::spawn(move || { + simple_chat::client::Connection::connect("test", addr, inner_messages).unwrap(); + }); + + let messages2 = std::sync::Arc::new(simple_chat::client::Messages::default()); + let inner_messages2 = messages2.clone(); + std::thread::spawn(move || { + let mut conn = + simple_chat::client::Connection::connect("test2", addr, inner_messages2).unwrap(); + conn.send(simple_chat::ClientMessage::SendMsg( + "Test message".to_string(), + )) + .unwrap(); + conn.send(simple_chat::ClientMessage::Leave).unwrap(); + }); + + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + assert_eq!(dbg!(messages2.get()).len(), 2); + assert_eq!(dbg!(messages.get()).len(), 4); +}