diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..e48b341 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,88 @@ +name: CI + +on: + push: + branches: ["*"] + pull_request: + branches: ["*"] + +env: + CARGO_TERM_COLOR: always + +jobs: + check: + name: Build, Lint & Test + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust stable + uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt, clippy + + - name: Cache cargo registry & build + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.toml') }} + + - name: Check formatting + run: cargo fmt --check + + - name: Clippy + run: cargo clippy -- -D warnings + + - name: Build + run: cargo build + + - name: Run tests + run: cargo test + + integration: + name: Server-Client Integration + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust stable + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo registry & build + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.toml') }} + + - name: Build server and client + run: cargo build + + - name: Start server and send a message from client + run: | + # Start the server in the background + CHAT_HOST=127.0.0.1 CHAT_PORT=9090 cargo run --bin server & + SERVER_PID=$! + sleep 2 + + # Use a simple script to connect, send a message, and leave + ( + echo '{"type":"Join","username":"ci-user"}' + sleep 1 + echo '{"type":"Send","content":"Hello from CI!"}' + sleep 1 + echo '{"type":"Leave"}' + sleep 1 + ) | nc 127.0.0.1 9090 + + # Give the server a moment to process + sleep 1 + + # Kill the server + kill $SERVER_PID 2>/dev/null || true + echo "Integration smoke test passed!" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..90ade83 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,3 @@ +[workspace] +members = ["protocol", "server", "client"] +resolver = "2" diff --git a/client/Cargo.toml b/client/Cargo.toml new file mode 100644 index 0000000..a1a3048 --- /dev/null +++ b/client/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "client" +version = "0.1.0" +edition = "2021" + +[dependencies] +protocol = { path = "../protocol" } +tokio = { version = "1", features = ["full"] } +clap = { version = "4", features = ["derive", "env"] } diff --git a/client/src/main.rs b/client/src/main.rs new file mode 100644 index 0000000..859a345 --- /dev/null +++ b/client/src/main.rs @@ -0,0 +1,134 @@ +use clap::Parser; +use protocol::{ClientMessage, ServerMessage}; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use tokio::net::TcpStream; + +#[derive(Parser)] +#[command(name = "chat-client", about = "Simple async chat client")] +struct Args { + /// Server host + #[arg(long, env = "CHAT_HOST", default_value = "127.0.0.1")] + host: String, + + /// Server port + #[arg(long, env = "CHAT_PORT", default_value = "8080")] + port: u16, + + /// Username for the chat + #[arg(long, env = "CHAT_USERNAME")] + username: String, +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let args = Args::parse(); + let addr = format!("{}:{}", args.host, args.port); + + let stream = TcpStream::connect(&addr).await?; + println!("Connected to {addr}"); + + let (reader, mut writer) = stream.into_split(); + let mut reader = BufReader::new(reader); + + // Send join message + let join_msg = protocol::encode(&ClientMessage::Join { + username: args.username.clone(), + })?; + writer.write_all(join_msg.as_bytes()).await?; + + // Read welcome/error response + let mut line = String::new(); + reader.read_line(&mut line).await?; + match protocol::decode::(&line)? { + ServerMessage::Welcome { message } => println!("{message}"), + ServerMessage::Error { message } => { + eprintln!("Error: {message}"); + return Ok(()); + } + _ => {} + } + + // Spawn task to read server messages + let recv_task = tokio::spawn(async move { + let mut line = String::new(); + loop { + line.clear(); + match reader.read_line(&mut line).await { + Ok(0) => { + println!("Server disconnected."); + break; + } + Ok(_) => { + if let Ok(msg) = protocol::decode::(&line) { + match msg { + ServerMessage::Chat { username, content } => { + println!("[{username}]: {content}"); + } + ServerMessage::UserJoined { username } => { + println!("* {username} joined the chat"); + } + ServerMessage::UserLeft { username } => { + println!("* {username} left the chat"); + } + ServerMessage::Error { message } => { + eprintln!("Server error: {message}"); + } + ServerMessage::Welcome { message } => { + println!("{message}"); + } + } + } + } + Err(e) => { + eprintln!("Read error: {e}"); + break; + } + } + } + }); + + // Read user input from stdin + let stdin = tokio::io::stdin(); + let mut stdin_reader = BufReader::new(stdin); + let mut input = String::new(); + + loop { + input.clear(); + match stdin_reader.read_line(&mut input).await { + Ok(0) => break, + Ok(_) => { + let trimmed = input.trim(); + if trimmed.is_empty() { + continue; + } + + if trimmed == "leave" { + let leave_msg = protocol::encode(&ClientMessage::Leave)?; + writer.write_all(leave_msg.as_bytes()).await?; + println!("Disconnected from chat."); + break; + } else if let Some(msg) = trimmed.strip_prefix("send ") { + if msg.is_empty() { + println!("Usage: send "); + continue; + } + let send_msg = protocol::encode(&ClientMessage::Send { + content: msg.to_string(), + })?; + writer.write_all(send_msg.as_bytes()).await?; + } else { + println!("Unknown command. Available commands:"); + println!(" send - Send a message to the chat"); + println!(" leave - Disconnect and exit"); + } + } + Err(e) => { + eprintln!("Input error: {e}"); + break; + } + } + } + + recv_task.abort(); + Ok(()) +} diff --git a/hooks/pre-commit b/hooks/pre-commit new file mode 100644 index 0000000..1202529 --- /dev/null +++ b/hooks/pre-commit @@ -0,0 +1,28 @@ +#!/bin/sh +# +# Pre-commit hook: ensures code is formatted, compiles, and passes clippy. + +set -e + +echo "Running cargo fmt --check..." +cargo fmt --check +if [ $? -ne 0 ]; then + echo "Error: Code is not formatted. Run 'cargo fmt' before committing." + exit 1 +fi + +echo "Running cargo clippy..." +cargo clippy -- -D warnings +if [ $? -ne 0 ]; then + echo "Error: Clippy found warnings. Fix them before committing." + exit 1 +fi + +echo "Running cargo build..." +cargo build +if [ $? -ne 0 ]; then + echo "Error: Build failed. Fix compilation errors before committing." + exit 1 +fi + +echo "All pre-commit checks passed!" diff --git a/protocol/Cargo.toml b/protocol/Cargo.toml new file mode 100644 index 0000000..f42ceb2 --- /dev/null +++ b/protocol/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "protocol" +version = "0.1.0" +edition = "2021" + +[dependencies] +serde = { version = "1", features = ["derive"] } +serde_json = "1" diff --git a/protocol/src/lib.rs b/protocol/src/lib.rs new file mode 100644 index 0000000..b91279e --- /dev/null +++ b/protocol/src/lib.rs @@ -0,0 +1,104 @@ +use serde::{Deserialize, Serialize}; + +/// Messages sent from client to server. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(tag = "type")] +pub enum ClientMessage { + Join { username: String }, + Leave, + Send { content: String }, +} + +/// Messages sent from server to client. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(tag = "type")] +pub enum ServerMessage { + Welcome { message: String }, + Error { message: String }, + UserJoined { username: String }, + UserLeft { username: String }, + Chat { username: String, content: String }, +} + +/// Encode a message as a newline-delimited JSON string. +pub fn encode(msg: &T) -> Result { + let mut s = serde_json::to_string(msg)?; + s.push('\n'); + Ok(s) +} + +/// Decode a message from a JSON string. +pub fn decode<'a, T: Deserialize<'a>>(s: &'a str) -> Result { + serde_json::from_str(s.trim()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_encode_decode_join() { + let msg = ClientMessage::Join { + username: "alice".to_string(), + }; + let encoded = encode(&msg).unwrap(); + assert!(encoded.ends_with('\n')); + let decoded: ClientMessage = decode(&encoded).unwrap(); + assert_eq!(decoded, msg); + } + + #[test] + fn test_encode_decode_send() { + let msg = ClientMessage::Send { + content: "hello world".to_string(), + }; + let encoded = encode(&msg).unwrap(); + let decoded: ClientMessage = decode(&encoded).unwrap(); + assert_eq!(decoded, msg); + } + + #[test] + fn test_encode_decode_leave() { + let msg = ClientMessage::Leave; + let encoded = encode(&msg).unwrap(); + let decoded: ClientMessage = decode(&encoded).unwrap(); + assert_eq!(decoded, msg); + } + + #[test] + fn test_encode_decode_server_chat() { + let msg = ServerMessage::Chat { + username: "bob".to_string(), + content: "hi there".to_string(), + }; + let encoded = encode(&msg).unwrap(); + let decoded: ServerMessage = decode(&encoded).unwrap(); + assert_eq!(decoded, msg); + } + + #[test] + fn test_encode_decode_server_welcome() { + let msg = ServerMessage::Welcome { + message: "Welcome!".to_string(), + }; + let encoded = encode(&msg).unwrap(); + let decoded: ServerMessage = decode(&encoded).unwrap(); + assert_eq!(decoded, msg); + } + + #[test] + fn test_encode_decode_server_error() { + let msg = ServerMessage::Error { + message: "username taken".to_string(), + }; + let encoded = encode(&msg).unwrap(); + let decoded: ServerMessage = decode(&encoded).unwrap(); + assert_eq!(decoded, msg); + } + + #[test] + fn test_decode_invalid_json() { + let result: Result = decode("not json"); + assert!(result.is_err()); + } +} diff --git a/server/Cargo.toml b/server/Cargo.toml new file mode 100644 index 0000000..46d6987 --- /dev/null +++ b/server/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "server" +version = "0.1.0" +edition = "2021" + +[lib] +name = "server_lib" +path = "src/lib.rs" + +[[bin]] +name = "server" +path = "src/main.rs" + +[dependencies] +protocol = { path = "../protocol" } +tokio = { version = "1", features = ["full"] } +tracing = "0.1" +tracing-subscriber = "0.3" diff --git a/server/src/lib.rs b/server/src/lib.rs new file mode 100644 index 0000000..addf7a5 --- /dev/null +++ b/server/src/lib.rs @@ -0,0 +1 @@ +pub mod room; diff --git a/server/src/main.rs b/server/src/main.rs new file mode 100644 index 0000000..f5db9eb --- /dev/null +++ b/server/src/main.rs @@ -0,0 +1,131 @@ +use server_lib::room::Room; +use std::sync::Arc; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use tokio::net::TcpListener; +use tracing::{error, info}; + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt::init(); + + let host = std::env::var("CHAT_HOST").unwrap_or_else(|_| "0.0.0.0".to_string()); + let port = std::env::var("CHAT_PORT").unwrap_or_else(|_| "8080".to_string()); + let addr = format!("{host}:{port}"); + + let listener = TcpListener::bind(&addr) + .await + .unwrap_or_else(|e| panic!("Failed to bind to {addr}: {e}")); + + info!("Chat server listening on {addr}"); + + let room = Arc::new(Room::new()); + + loop { + let (stream, peer_addr) = match listener.accept().await { + Ok(conn) => conn, + Err(e) => { + error!("Failed to accept connection: {e}"); + continue; + } + }; + + info!("New connection from {peer_addr}"); + let room = Arc::clone(&room); + + tokio::spawn(async move { + if let Err(e) = handle_connection(stream, room).await { + error!("Connection error from {peer_addr}: {e}"); + } + }); + } +} + +async fn handle_connection( + stream: tokio::net::TcpStream, + room: Arc, +) -> Result<(), Box> { + let (reader, mut writer) = stream.into_split(); + let mut reader = BufReader::new(reader); + let mut line = String::new(); + + // First message must be a Join + let n = reader.read_line(&mut line).await?; + if n == 0 { + return Ok(()); + } + + let username = match protocol::decode::(&line)? { + protocol::ClientMessage::Join { username } => username, + _ => { + let err = protocol::encode(&protocol::ServerMessage::Error { + message: "First message must be a Join".to_string(), + })?; + writer.write_all(err.as_bytes()).await?; + return Ok(()); + } + }; + + let mut rx = match room.join(&username) { + Ok(rx) => rx, + Err(msg) => { + let err = protocol::encode(&protocol::ServerMessage::Error { message: msg })?; + writer.write_all(err.as_bytes()).await?; + return Ok(()); + } + }; + + let welcome = protocol::encode(&protocol::ServerMessage::Welcome { + message: format!("Welcome, {username}!"), + })?; + writer.write_all(welcome.as_bytes()).await?; + + info!("{username} joined the chat"); + + let write_username = username.clone(); + let write_task = tokio::spawn(async move { + while let Ok(msg) = rx.recv().await { + let encoded = match protocol::encode(&msg) { + Ok(e) => e, + Err(_) => continue, + }; + if writer.write_all(encoded.as_bytes()).await.is_err() { + break; + } + } + write_username + }); + + loop { + line.clear(); + let n = reader.read_line(&mut line).await; + + match n { + Ok(0) | Err(_) => { + // Client disconnected + break; + } + Ok(_) => {} + } + + match protocol::decode::(&line) { + Ok(protocol::ClientMessage::Send { content }) => { + room.broadcast(&username, content); + } + Ok(protocol::ClientMessage::Leave) => { + break; + } + Ok(protocol::ClientMessage::Join { .. }) => { + // Ignore duplicate join + } + Err(e) => { + error!("Failed to decode message from {username}: {e}"); + } + } + } + + room.leave(&username); + info!("{username} left the chat"); + write_task.abort(); + + Ok(()) +} diff --git a/server/src/room.rs b/server/src/room.rs new file mode 100644 index 0000000..8694276 --- /dev/null +++ b/server/src/room.rs @@ -0,0 +1,174 @@ +use protocol::ServerMessage; +use std::collections::HashMap; +use std::sync::Mutex; +use tokio::sync::broadcast; + +const CHANNEL_CAPACITY: usize = 256; + +/// A chat room that manages connected users and message broadcasting. +pub struct Room { + state: Mutex, +} + +struct RoomState { + users: HashMap>, +} + +impl Default for Room { + fn default() -> Self { + Self::new() + } +} + +impl Room { + pub fn new() -> Self { + Self { + state: Mutex::new(RoomState { + users: HashMap::new(), + }), + } + } + + /// Add a user to the room. Returns a receiver for incoming messages, + /// or an error string if the username is taken. + pub fn join(&self, username: &str) -> Result, String> { + let mut state = self.state.lock().unwrap(); + + if state.users.contains_key(username) { + return Err(format!("Username '{username}' is already taken")); + } + + let (tx, rx) = broadcast::channel(CHANNEL_CAPACITY); + state.users.insert(username.to_string(), tx); + + // Notify all other users + let join_msg = ServerMessage::UserJoined { + username: username.to_string(), + }; + for (name, sender) in &state.users { + if name != username { + let _ = sender.send(join_msg.clone()); + } + } + + Ok(rx) + } + + /// Remove a user from the room and notify others. + pub fn leave(&self, username: &str) { + let mut state = self.state.lock().unwrap(); + state.users.remove(username); + + let leave_msg = ServerMessage::UserLeft { + username: username.to_string(), + }; + for sender in state.users.values() { + let _ = sender.send(leave_msg.clone()); + } + } + + /// Broadcast a chat message to all users except the sender. + pub fn broadcast(&self, sender_username: &str, content: String) { + let state = self.state.lock().unwrap(); + + let chat_msg = ServerMessage::Chat { + username: sender_username.to_string(), + content, + }; + + for (name, sender) in &state.users { + if name != sender_username { + let _ = sender.send(chat_msg.clone()); + } + } + } + + /// Returns the number of connected users. + #[cfg(test)] + pub fn user_count(&self) -> usize { + self.state.lock().unwrap().users.len() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_join_and_leave() { + let room = Room::new(); + assert_eq!(room.user_count(), 0); + + let _rx = room.join("alice").unwrap(); + assert_eq!(room.user_count(), 1); + + let _rx2 = room.join("bob").unwrap(); + assert_eq!(room.user_count(), 2); + + room.leave("alice"); + assert_eq!(room.user_count(), 1); + + room.leave("bob"); + assert_eq!(room.user_count(), 0); + } + + #[test] + fn test_duplicate_username_rejected() { + let room = Room::new(); + let _rx = room.join("alice").unwrap(); + let result = room.join("alice"); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("already taken")); + } + + #[tokio::test] + async fn test_broadcast_excludes_sender() { + let room = Room::new(); + let _rx_alice = room.join("alice").unwrap(); + let mut rx_bob = room.join("bob").unwrap(); + + room.broadcast("alice", "hello".to_string()); + + let msg = rx_bob.recv().await.unwrap(); + match msg { + ServerMessage::Chat { username, content } => { + assert_eq!(username, "alice"); + assert_eq!(content, "hello"); + } + _ => panic!("Expected Chat message"), + } + } + + #[tokio::test] + async fn test_join_notification() { + let room = Room::new(); + let mut rx_alice = room.join("alice").unwrap(); + + let _rx_bob = room.join("bob").unwrap(); + + let msg = rx_alice.recv().await.unwrap(); + match msg { + ServerMessage::UserJoined { username } => { + assert_eq!(username, "bob"); + } + _ => panic!("Expected UserJoined message"), + } + } + + #[tokio::test] + async fn test_leave_notification() { + let room = Room::new(); + let _rx_alice = room.join("alice").unwrap(); + let mut rx_bob = room.join("bob").unwrap(); + + room.leave("alice"); + + let msg = rx_bob.recv().await.unwrap(); + match msg { + ServerMessage::UserLeft { username } => { + assert_eq!(username, "alice"); + } + _ => panic!("Expected UserLeft message"), + } + } +} diff --git a/server/tests/integration_test.rs b/server/tests/integration_test.rs new file mode 100644 index 0000000..8d84eb9 --- /dev/null +++ b/server/tests/integration_test.rs @@ -0,0 +1,341 @@ +use std::time::Duration; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use tokio::net::TcpStream; +use tokio::time::sleep; + +use server_lib::room::Room; + +async fn start_server() -> u16 { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + + tokio::spawn(async move { + let room = std::sync::Arc::new(Room::new()); + loop { + let (stream, _) = listener.accept().await.unwrap(); + let room = std::sync::Arc::clone(&room); + tokio::spawn(handle_connection(stream, room)); + } + }); + + sleep(Duration::from_millis(50)).await; + port +} + +async fn handle_connection(stream: tokio::net::TcpStream, room: std::sync::Arc) { + let (reader, mut writer) = stream.into_split(); + let mut reader = BufReader::new(reader); + let mut line = String::new(); + + let n = reader.read_line(&mut line).await.unwrap(); + if n == 0 { + return; + } + + let username = match protocol::decode::(&line).unwrap() { + protocol::ClientMessage::Join { username } => username, + _ => return, + }; + + let mut rx = match room.join(&username) { + Ok(rx) => rx, + Err(msg) => { + let err = protocol::encode(&protocol::ServerMessage::Error { message: msg }).unwrap(); + writer.write_all(err.as_bytes()).await.unwrap(); + return; + } + }; + + let welcome = protocol::encode(&protocol::ServerMessage::Welcome { + message: format!("Welcome, {username}!"), + }) + .unwrap(); + writer.write_all(welcome.as_bytes()).await.unwrap(); + + let write_task = tokio::spawn(async move { + while let Ok(msg) = rx.recv().await { + let encoded = protocol::encode(&msg).unwrap(); + if writer.write_all(encoded.as_bytes()).await.is_err() { + break; + } + } + }); + + loop { + line.clear(); + match reader.read_line(&mut line).await { + Ok(0) | Err(_) => break, + Ok(_) => {} + } + + match protocol::decode::(&line) { + Ok(protocol::ClientMessage::Send { content }) => { + room.broadcast(&username, content); + } + Ok(protocol::ClientMessage::Leave) => break, + _ => {} + } + } + + room.leave(&username); + write_task.abort(); +} + +async fn connect_and_join( + port: u16, + username: &str, +) -> ( + tokio::net::tcp::OwnedWriteHalf, + BufReader, +) { + let stream = TcpStream::connect(format!("127.0.0.1:{port}")) + .await + .unwrap(); + let (reader, mut writer) = stream.into_split(); + let mut reader = BufReader::new(reader); + + let join_msg = protocol::encode(&protocol::ClientMessage::Join { + username: username.to_string(), + }) + .unwrap(); + writer.write_all(join_msg.as_bytes()).await.unwrap(); + + let mut line = String::new(); + reader.read_line(&mut line).await.unwrap(); + let msg: protocol::ServerMessage = protocol::decode(&line).unwrap(); + match msg { + protocol::ServerMessage::Welcome { .. } => {} + protocol::ServerMessage::Error { message } => panic!("Join failed: {message}"), + other => panic!("Unexpected message: {other:?}"), + } + + (writer, reader) +} + +#[tokio::test] +async fn test_single_user_join() { + let port = start_server().await; + let (_writer, _reader) = connect_and_join(port, "alice").await; +} + +#[tokio::test] +async fn test_duplicate_username_rejected() { + let port = start_server().await; + let (_w1, _r1) = connect_and_join(port, "alice").await; + + let stream = TcpStream::connect(format!("127.0.0.1:{port}")) + .await + .unwrap(); + let (reader, mut writer) = stream.into_split(); + let mut reader = BufReader::new(reader); + + let join_msg = protocol::encode(&protocol::ClientMessage::Join { + username: "alice".to_string(), + }) + .unwrap(); + writer.write_all(join_msg.as_bytes()).await.unwrap(); + + let mut line = String::new(); + reader.read_line(&mut line).await.unwrap(); + let msg: protocol::ServerMessage = protocol::decode(&line).unwrap(); + match msg { + protocol::ServerMessage::Error { message } => { + assert!(message.contains("already taken")); + } + other => panic!("Expected error, got: {other:?}"), + } +} + +#[tokio::test] +async fn test_message_broadcast() { + let port = start_server().await; + let (mut w_alice, _r_alice) = connect_and_join(port, "alice").await; + + sleep(Duration::from_millis(50)).await; + + let (_w_bob, mut r_bob) = connect_and_join(port, "bob").await; + + sleep(Duration::from_millis(50)).await; + + let send_msg = protocol::encode(&protocol::ClientMessage::Send { + content: "hello bob".to_string(), + }) + .unwrap(); + w_alice.write_all(send_msg.as_bytes()).await.unwrap(); + + let mut line = String::new(); + r_bob.read_line(&mut line).await.unwrap(); + let msg: protocol::ServerMessage = protocol::decode(&line).unwrap(); + match msg { + protocol::ServerMessage::Chat { username, content } => { + assert_eq!(username, "alice"); + assert_eq!(content, "hello bob"); + } + other => panic!("Expected Chat, got: {other:?}"), + } +} + +#[tokio::test] +async fn test_leave_notification() { + let port = start_server().await; + let (mut w_alice, _r_alice) = connect_and_join(port, "alice").await; + + sleep(Duration::from_millis(50)).await; + + let (_w_bob, mut r_bob) = connect_and_join(port, "bob").await; + + sleep(Duration::from_millis(50)).await; + + let leave_msg = protocol::encode(&protocol::ClientMessage::Leave).unwrap(); + w_alice.write_all(leave_msg.as_bytes()).await.unwrap(); + + let mut line = String::new(); + r_bob.read_line(&mut line).await.unwrap(); + let msg: protocol::ServerMessage = protocol::decode(&line).unwrap(); + match msg { + protocol::ServerMessage::UserLeft { username } => { + assert_eq!(username, "alice"); + } + other => panic!("Expected UserLeft, got: {other:?}"), + } +} + +#[tokio::test] +async fn test_join_notification() { + let port = start_server().await; + let (_w_alice, mut r_alice) = connect_and_join(port, "alice").await; + + sleep(Duration::from_millis(50)).await; + + let (_w_bob, _r_bob) = connect_and_join(port, "bob").await; + + let mut line = String::new(); + r_alice.read_line(&mut line).await.unwrap(); + let msg: protocol::ServerMessage = protocol::decode(&line).unwrap(); + match msg { + protocol::ServerMessage::UserJoined { username } => { + assert_eq!(username, "bob"); + } + other => panic!("Expected UserJoined, got: {other:?}"), + } +} + +#[tokio::test] +async fn test_sender_does_not_receive_own_message() { + let port = start_server().await; + let (mut w_alice, mut r_alice) = connect_and_join(port, "alice").await; + + sleep(Duration::from_millis(50)).await; + + let (_w_bob, mut r_bob) = connect_and_join(port, "bob").await; + + // Drain alice's join notification for bob + let mut line = String::new(); + r_alice.read_line(&mut line).await.unwrap(); + + sleep(Duration::from_millis(50)).await; + + let send_msg = protocol::encode(&protocol::ClientMessage::Send { + content: "test".to_string(), + }) + .unwrap(); + w_alice.write_all(send_msg.as_bytes()).await.unwrap(); + + let mut bob_line = String::new(); + r_bob.read_line(&mut bob_line).await.unwrap(); + let bob_msg: protocol::ServerMessage = protocol::decode(&bob_line).unwrap(); + assert!(matches!(bob_msg, protocol::ServerMessage::Chat { .. })); + + let result = tokio::time::timeout(Duration::from_millis(200), async { + let mut alice_line = String::new(); + r_alice.read_line(&mut alice_line).await + }) + .await; + + assert!(result.is_err(), "Alice should not receive her own message"); +} + +#[tokio::test] +async fn test_multiple_users_chat() { + let port = start_server().await; + let (mut w_alice, mut r_alice) = connect_and_join(port, "alice").await; + + sleep(Duration::from_millis(50)).await; + + let (mut w_bob, mut r_bob) = connect_and_join(port, "bob").await; + + // Drain alice's join notification + let mut line = String::new(); + r_alice.read_line(&mut line).await.unwrap(); + + sleep(Duration::from_millis(50)).await; + + let (_w_charlie, mut r_charlie) = connect_and_join(port, "charlie").await; + + // Drain join notifications + let mut a_line = String::new(); + r_alice.read_line(&mut a_line).await.unwrap(); + let mut b_line = String::new(); + r_bob.read_line(&mut b_line).await.unwrap(); + + sleep(Duration::from_millis(50)).await; + + // Alice sends a message + let send_msg = protocol::encode(&protocol::ClientMessage::Send { + content: "hi everyone".to_string(), + }) + .unwrap(); + w_alice.write_all(send_msg.as_bytes()).await.unwrap(); + + // Both bob and charlie receive it + let mut bob_line = String::new(); + r_bob.read_line(&mut bob_line).await.unwrap(); + let bob_msg: protocol::ServerMessage = protocol::decode(&bob_line).unwrap(); + match bob_msg { + protocol::ServerMessage::Chat { username, content } => { + assert_eq!(username, "alice"); + assert_eq!(content, "hi everyone"); + } + other => panic!("Expected Chat for bob, got: {other:?}"), + } + + let mut charlie_line = String::new(); + r_charlie.read_line(&mut charlie_line).await.unwrap(); + let charlie_msg: protocol::ServerMessage = protocol::decode(&charlie_line).unwrap(); + match charlie_msg { + protocol::ServerMessage::Chat { username, content } => { + assert_eq!(username, "alice"); + assert_eq!(content, "hi everyone"); + } + other => panic!("Expected Chat for charlie, got: {other:?}"), + } + + // Bob replies + let reply = protocol::encode(&protocol::ClientMessage::Send { + content: "hey alice".to_string(), + }) + .unwrap(); + w_bob.write_all(reply.as_bytes()).await.unwrap(); + + let mut a2 = String::new(); + r_alice.read_line(&mut a2).await.unwrap(); + let a2_msg: protocol::ServerMessage = protocol::decode(&a2).unwrap(); + match a2_msg { + protocol::ServerMessage::Chat { username, content } => { + assert_eq!(username, "bob"); + assert_eq!(content, "hey alice"); + } + other => panic!("Expected Chat for alice, got: {other:?}"), + } + + let mut c2 = String::new(); + r_charlie.read_line(&mut c2).await.unwrap(); + let c2_msg: protocol::ServerMessage = protocol::decode(&c2).unwrap(); + match c2_msg { + protocol::ServerMessage::Chat { username, content } => { + assert_eq!(username, "bob"); + assert_eq!(content, "hey alice"); + } + other => panic!("Expected Chat for charlie, got: {other:?}"), + } +}