|
| 1 | +use rocket::{ |
| 2 | + time::{Duration, OffsetDateTime}, |
| 3 | + tokio::{ |
| 4 | + sync::{oneshot, Mutex}, |
| 5 | + time::interval, |
| 6 | + }, |
| 7 | +}; |
| 8 | + |
| 9 | +use crate::error::{SessionError, SessionResult}; |
| 10 | + |
| 11 | +pub(super) const ID_COLUMN: &str = "id"; |
| 12 | +pub(super) const DATA_COLUMN: &str = "data"; |
| 13 | +pub(super) const EXPIRES_COLUMN: &str = "expires"; |
| 14 | + |
| 15 | +pub(super) struct SqlxBase<DB: sqlx::Database> { |
| 16 | + pool: sqlx::Pool<DB>, |
| 17 | + table_name: String, |
| 18 | + index_column: String, |
| 19 | +} |
| 20 | + |
| 21 | +impl<DB> SqlxBase<DB> |
| 22 | +where |
| 23 | + DB: sqlx::Database, |
| 24 | + for<'q> <DB as sqlx::Database>::Arguments<'q>: sqlx::IntoArguments<'q, DB>, |
| 25 | + for<'c> &'c mut <DB as sqlx::Database>::Connection: sqlx::Executor<'c, Database = DB>, |
| 26 | + OffsetDateTime: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>, |
| 27 | +{ |
| 28 | + pub fn new(pool: sqlx::Pool<DB>, table_name: String, index_column: String) -> Self { |
| 29 | + SqlxBase { |
| 30 | + pool, |
| 31 | + table_name, |
| 32 | + index_column, |
| 33 | + } |
| 34 | + } |
| 35 | + |
| 36 | + pub async fn load<'a>( |
| 37 | + &self, |
| 38 | + id: String, |
| 39 | + ttl: Option<u32>, |
| 40 | + ) -> Result<Option<DB::Row>, sqlx::Error> |
| 41 | + where |
| 42 | + String: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>, |
| 43 | + { |
| 44 | + match ttl { |
| 45 | + Some(new_ttl) => { |
| 46 | + sqlx::query(&load_and_update_ttl_sql(&self.table_name)) |
| 47 | + .bind(OffsetDateTime::now_utc() + Duration::seconds(new_ttl.into())) |
| 48 | + .bind(id) |
| 49 | + .fetch_optional(&self.pool) |
| 50 | + .await |
| 51 | + } |
| 52 | + None => { |
| 53 | + sqlx::query(&load_sql(&self.table_name)) |
| 54 | + .bind(id) |
| 55 | + .fetch_optional(&self.pool) |
| 56 | + .await |
| 57 | + } |
| 58 | + } |
| 59 | + } |
| 60 | + |
| 61 | + pub async fn save<'a, V, I>( |
| 62 | + &'a self, |
| 63 | + id: String, |
| 64 | + value: V, |
| 65 | + index: I, |
| 66 | + ttl: u32, |
| 67 | + ) -> Result<(), sqlx::Error> |
| 68 | + where |
| 69 | + String: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>, |
| 70 | + V: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>, |
| 71 | + I: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>, |
| 72 | + { |
| 73 | + sqlx::query(&save_sql(&self.table_name, &self.index_column)) |
| 74 | + .bind(id) |
| 75 | + .bind(index) |
| 76 | + .bind(value) |
| 77 | + .bind(OffsetDateTime::now_utc() + Duration::seconds(ttl.into())) |
| 78 | + .execute(&self.pool) |
| 79 | + .await?; |
| 80 | + Ok(()) |
| 81 | + } |
| 82 | + |
| 83 | + pub async fn delete<'a>(&self, id: String) -> Result<(), sqlx::Error> |
| 84 | + where |
| 85 | + String: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>, |
| 86 | + { |
| 87 | + sqlx::query(&delete_sql(&self.table_name)) |
| 88 | + .bind(id) |
| 89 | + .execute(&self.pool) |
| 90 | + .await?; |
| 91 | + Ok(()) |
| 92 | + } |
| 93 | +} |
| 94 | + |
| 95 | +/// Bind session ID |
| 96 | +pub(super) fn load_sql(table_name: &str) -> String { |
| 97 | + format!( |
| 98 | + "SELECT {DATA_COLUMN}, {EXPIRES_COLUMN} FROM \"{table_name}\" \ |
| 99 | + WHERE {ID_COLUMN} = $1 AND {EXPIRES_COLUMN} > CURRENT_TIMESTAMP" |
| 100 | + ) |
| 101 | +} |
| 102 | + |
| 103 | +/// Bind expiration and session ID |
| 104 | +pub(super) fn load_and_update_ttl_sql(table_name: &str) -> String { |
| 105 | + format!( |
| 106 | + "UPDATE \"{table_name}\" SET {EXPIRES_COLUMN} = $1 \ |
| 107 | + WHERE {ID_COLUMN} = $2 AND {EXPIRES_COLUMN} > CURRENT_TIMESTAMP \ |
| 108 | + RETURNING {DATA_COLUMN}, {EXPIRES_COLUMN}", |
| 109 | + ) |
| 110 | +} |
| 111 | + |
| 112 | +/// Bind the session ID, index, data, and expiration |
| 113 | +pub(super) fn save_sql(table_name: &str, index_column: &str) -> String { |
| 114 | + format!( |
| 115 | + "INSERT INTO \"{table_name}\" ({ID_COLUMN}, {index_column}, {DATA_COLUMN}, {EXPIRES_COLUMN}) \ |
| 116 | + VALUES ($1, $2, $3, $4) \ |
| 117 | + ON CONFLICT ({ID_COLUMN}) DO UPDATE SET \ |
| 118 | + {DATA_COLUMN} = EXCLUDED.{DATA_COLUMN}, \ |
| 119 | + {EXPIRES_COLUMN} = EXCLUDED.{EXPIRES_COLUMN}" |
| 120 | + ) |
| 121 | +} |
| 122 | + |
| 123 | +/// Bind the session ID |
| 124 | +pub(super) fn delete_sql(table_name: &str) -> String { |
| 125 | + format!("DELETE FROM \"{table_name}\" WHERE {ID_COLUMN} = $1") |
| 126 | +} |
| 127 | + |
| 128 | +pub(super) fn expires_to_ttl(expires: OffsetDateTime) -> u32 { |
| 129 | + (expires - OffsetDateTime::now_utc()) |
| 130 | + .whole_seconds() |
| 131 | + .try_into() |
| 132 | + .unwrap_or(0) |
| 133 | +} |
| 134 | + |
| 135 | +/// Session cleanup task |
| 136 | +#[derive(Default)] |
| 137 | +pub(super) struct SqlxCleanupTask { |
| 138 | + interval: Option<std::time::Duration>, |
| 139 | + shutdown_tx: Mutex<Option<oneshot::Sender<u8>>>, |
| 140 | + table_name: String, |
| 141 | +} |
| 142 | + |
| 143 | +impl SqlxCleanupTask { |
| 144 | + pub(super) fn new(cleanup_interval: Option<std::time::Duration>, table_name: &str) -> Self { |
| 145 | + Self { |
| 146 | + interval: cleanup_interval, |
| 147 | + shutdown_tx: Mutex::default(), |
| 148 | + table_name: table_name.to_string(), |
| 149 | + } |
| 150 | + } |
| 151 | + |
| 152 | + pub(super) async fn setup<DB>(&self, pool: &sqlx::Pool<DB>) -> SessionResult<()> |
| 153 | + where |
| 154 | + DB: sqlx::Database, |
| 155 | + for<'q> <DB as sqlx::Database>::Arguments<'q>: sqlx::IntoArguments<'q, DB>, |
| 156 | + for<'c> &'c mut <DB as sqlx::Database>::Connection: sqlx::Executor<'c, Database = DB>, |
| 157 | + OffsetDateTime: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>, |
| 158 | + { |
| 159 | + let Some(cleanup_interval) = self.interval else { |
| 160 | + return Ok(()); |
| 161 | + }; |
| 162 | + |
| 163 | + let (tx, mut rx) = oneshot::channel(); |
| 164 | + self.shutdown_tx.lock().await.replace(tx); |
| 165 | + |
| 166 | + let pool = pool.clone(); |
| 167 | + let table_name = self.table_name.clone(); |
| 168 | + rocket::tokio::spawn(async move { |
| 169 | + rocket::info!("Starting session cleanup monitor"); |
| 170 | + let mut interval = interval(cleanup_interval); |
| 171 | + loop { |
| 172 | + rocket::tokio::select! { |
| 173 | + _ = interval.tick() => { |
| 174 | + rocket::debug!("Cleaning up expired sessions"); |
| 175 | + if let Err(e) = sqlx::query(&format!( |
| 176 | + "DELETE FROM \"{table_name}\" WHERE {EXPIRES_COLUMN} < $1" |
| 177 | + )) |
| 178 | + .bind(OffsetDateTime::now_utc()) |
| 179 | + .execute(&pool) |
| 180 | + .await |
| 181 | + { |
| 182 | + rocket::error!("Error deleting expired sessions: {e}"); |
| 183 | + } |
| 184 | + } |
| 185 | + _ = &mut rx => { |
| 186 | + rocket::info!("Session cleanup monitor shutdown"); |
| 187 | + } |
| 188 | + } |
| 189 | + } |
| 190 | + }); |
| 191 | + |
| 192 | + Ok(()) |
| 193 | + } |
| 194 | + |
| 195 | + pub(super) async fn shutdown(&self) -> SessionResult<()> { |
| 196 | + if let Some(tx) = self.shutdown_tx.lock().await.take() { |
| 197 | + tx.send(0).map_err(|_| { |
| 198 | + SessionError::SetupTeardown("Failed to send shutdown signal".to_string()) |
| 199 | + })?; |
| 200 | + } |
| 201 | + Ok(()) |
| 202 | + } |
| 203 | +} |
0 commit comments