Retweet deduplication: stop keeping a non-retweet candidate whose ID was already seen.

In the `None` match arm the candidate is now pushed to `kept` only when
`seen_tweet_ids.insert(candidate.tweet_id as u64)` returns true; otherwise it
is pushed to `removed`. Before this change the arm inserted the ID and always
kept the candidate.
NOTE(review): assumes `seen_tweet_ids` is a set whose `insert` returns false
for a value already present (e.g. std `HashSet`) -- its declaration is outside
these hunks; confirm.

Adds a `tests` module with two tokio tests, one per arrival order (retweet
before original, original before retweet). Each asserts exactly one kept and
one removed candidate.

NOTE(review): this patch was recovered from a whitespace-collapsed copy. Line
breaks and the `+` marker on one `retweeted_tweet_id: None,` line were
restored so the hunk counts match their headers; indentation (context lines
included) is reconstructed -- verify against the target file before applying.

diff --git a/home-mixer/filters/retweet_deduplication_filter.rs b/home-mixer/filters/retweet_deduplication_filter.rs
index 1216f5a..71c015a 100644
--- a/home-mixer/filters/retweet_deduplication_filter.rs
+++ b/home-mixer/filters/retweet_deduplication_filter.rs
@@ -31,8 +31,11 @@ impl Filter for RetweetDeduplicationFilter {
                 }
                 None => {
                     // Mark this original tweet ID as seen so retweets of it get filtered
-                    seen_tweet_ids.insert(candidate.tweet_id as u64);
-                    kept.push(candidate);
+                    if seen_tweet_ids.insert(candidate.tweet_id as u64) {
+                        kept.push(candidate);
+                    } else {
+                        removed.push(candidate);
+                    }
                 }
             }
         }
@@ -40,3 +43,70 @@ impl Filter for RetweetDeduplicationFilter {
         Ok(FilterResult { kept, removed })
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::candidate_pipeline::candidate::PostCandidate;
+    use crate::candidate_pipeline::query::ScoredPostsQuery;
+
+    #[tokio::test]
+    async fn test_retweet_deduplication_filter_removes_original_if_retweet_seen_first() {
+        let filter = RetweetDeduplicationFilter;
+        let query = ScoredPostsQuery::default();
+
+        let retweet_candidate = PostCandidate {
+            tweet_id: 100,
+            retweeted_tweet_id: Some(200),
+            ..Default::default()
+        };
+
+        let original_candidate = PostCandidate {
+            tweet_id: 200,
+            retweeted_tweet_id: None,
+            ..Default::default()
+        };
+
+        let candidates = vec![retweet_candidate, original_candidate];
+
+        let result = filter.filter(&query, candidates).await.unwrap();
+
+        // Should keep the retweet (first occurrence of 200)
+        assert_eq!(result.kept.len(), 1);
+        assert_eq!(result.kept[0].tweet_id, 100);
+
+        // Should remove the original tweet (second occurrence of 200)
+        assert_eq!(result.removed.len(), 1);
+        assert_eq!(result.removed[0].tweet_id, 200);
+    }
+
+    #[tokio::test]
+    async fn test_retweet_deduplication_filter_removes_retweet_if_original_seen_first() {
+        let filter = RetweetDeduplicationFilter;
+        let query = ScoredPostsQuery::default();
+
+        let original_candidate = PostCandidate {
+            tweet_id: 200,
+            retweeted_tweet_id: None,
+            ..Default::default()
+        };
+
+        let retweet_candidate = PostCandidate {
+            tweet_id: 100,
+            retweeted_tweet_id: Some(200),
+            ..Default::default()
+        };
+
+        let candidates = vec![original_candidate, retweet_candidate];
+
+        let result = filter.filter(&query, candidates).await.unwrap();
+
+        // Should keep the original tweet (first occurrence of 200)
+        assert_eq!(result.kept.len(), 1);
+        assert_eq!(result.kept[0].tweet_id, 200);
+
+        // Should remove the retweet (second occurrence of 200)
+        assert_eq!(result.removed.len(), 1);
+        assert_eq!(result.removed[0].tweet_id, 100);
+    }
+}