Skip to content

Commit 658a005

Browse files
committed
add sqlx base functions and sqlite storage
1 parent a22c607 commit 658a005

7 files changed

Lines changed: 352 additions & 135 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ exclude = [".github/", ".zed/", ".config/"]
2323
cookie = ["dep:time"]
2424
redis_fred = ["dep:fred"]
2525
rocket_okapi = ["dep:rocket_okapi"]
26-
sqlx_postgres = ["dep:sqlx"]
26+
sqlx_postgres = ["dep:sqlx", "sqlx/postgres"]
27+
sqlx_sqlite = ["dep:sqlx", "sqlx/sqlite"]
2728

2829
[package.metadata.docs.rs]
2930
all-features = true
@@ -42,7 +43,6 @@ rocket = { version = "~0.5.1", features = ["secrets"] }
4243
rocket_okapi = { version = "0.9", optional = true }
4344
sqlx = { version = "0.8", optional = true, default-features = false, features = [
4445
"runtime-tokio",
45-
"postgres",
4646
"time",
4747
] }
4848
thiserror = "2.0"

src/storage.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,5 @@ pub mod cookie;
3030
#[cfg(any(feature = "redis_fred"))]
3131
pub mod redis;
3232

33-
#[cfg(any(feature = "sqlx_postgres"))]
33+
#[cfg(any(feature = "sqlx_postgres", feature = "sqlx_sqlite"))]
3434
pub mod sqlx;

src/storage/sqlx.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
11
//! Session storage via sqlx
22
3+
mod base;
4+
use base::*;
5+
36
#[cfg(feature = "sqlx_postgres")]
47
mod postgres;
58
#[cfg(feature = "sqlx_postgres")]
69
pub use postgres::SqlxPostgresStorage;
710

11+
#[cfg(feature = "sqlx_sqlite")]
12+
mod sqlite;
13+
#[cfg(feature = "sqlx_sqlite")]
14+
pub use sqlite::SqlxSqliteStorage;
15+
816
use crate::SessionIdentifier;
917

1018
/**

src/storage/sqlx/base.rs

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
use rocket::{
2+
time::{Duration, OffsetDateTime},
3+
tokio::{
4+
sync::{oneshot, Mutex},
5+
time::interval,
6+
},
7+
};
8+
9+
use crate::error::{SessionError, SessionResult};
10+
11+
pub(super) const ID_COLUMN: &str = "id";
12+
pub(super) const DATA_COLUMN: &str = "data";
13+
pub(super) const EXPIRES_COLUMN: &str = "expires";
14+
15+
pub(super) struct SqlxBase<DB: sqlx::Database> {
16+
pool: sqlx::Pool<DB>,
17+
table_name: String,
18+
index_column: String,
19+
}
20+
21+
impl<DB> SqlxBase<DB>
22+
where
23+
DB: sqlx::Database,
24+
for<'q> <DB as sqlx::Database>::Arguments<'q>: sqlx::IntoArguments<'q, DB>,
25+
for<'c> &'c mut <DB as sqlx::Database>::Connection: sqlx::Executor<'c, Database = DB>,
26+
OffsetDateTime: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
27+
{
28+
pub fn new(pool: sqlx::Pool<DB>, table_name: String, index_column: String) -> Self {
29+
SqlxBase {
30+
pool,
31+
table_name,
32+
index_column,
33+
}
34+
}
35+
36+
pub async fn load<'a>(
37+
&self,
38+
id: String,
39+
ttl: Option<u32>,
40+
) -> Result<Option<DB::Row>, sqlx::Error>
41+
where
42+
String: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
43+
{
44+
match ttl {
45+
Some(new_ttl) => {
46+
sqlx::query(&load_and_update_ttl_sql(&self.table_name))
47+
.bind(OffsetDateTime::now_utc() + Duration::seconds(new_ttl.into()))
48+
.bind(id)
49+
.fetch_optional(&self.pool)
50+
.await
51+
}
52+
None => {
53+
sqlx::query(&load_sql(&self.table_name))
54+
.bind(id)
55+
.fetch_optional(&self.pool)
56+
.await
57+
}
58+
}
59+
}
60+
61+
pub async fn save<'a, V, I>(
62+
&'a self,
63+
id: String,
64+
value: V,
65+
index: I,
66+
ttl: u32,
67+
) -> Result<(), sqlx::Error>
68+
where
69+
String: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
70+
V: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
71+
I: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
72+
{
73+
sqlx::query(&save_sql(&self.table_name, &self.index_column))
74+
.bind(id)
75+
.bind(index)
76+
.bind(value)
77+
.bind(OffsetDateTime::now_utc() + Duration::seconds(ttl.into()))
78+
.execute(&self.pool)
79+
.await?;
80+
Ok(())
81+
}
82+
83+
pub async fn delete<'a>(&self, id: String) -> Result<(), sqlx::Error>
84+
where
85+
String: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
86+
{
87+
sqlx::query(&delete_sql(&self.table_name))
88+
.bind(id)
89+
.execute(&self.pool)
90+
.await?;
91+
Ok(())
92+
}
93+
}
94+
95+
/// Bind session ID
96+
pub(super) fn load_sql(table_name: &str) -> String {
97+
format!(
98+
"SELECT {DATA_COLUMN}, {EXPIRES_COLUMN} FROM \"{table_name}\" \
99+
WHERE {ID_COLUMN} = $1 AND {EXPIRES_COLUMN} > CURRENT_TIMESTAMP"
100+
)
101+
}
102+
103+
/// Bind expiration and session ID
104+
pub(super) fn load_and_update_ttl_sql(table_name: &str) -> String {
105+
format!(
106+
"UPDATE \"{table_name}\" SET {EXPIRES_COLUMN} = $1 \
107+
WHERE {ID_COLUMN} = $2 AND {EXPIRES_COLUMN} > CURRENT_TIMESTAMP \
108+
RETURNING {DATA_COLUMN}, {EXPIRES_COLUMN}",
109+
)
110+
}
111+
112+
/// Bind the session ID, index, data, and expiration
113+
pub(super) fn save_sql(table_name: &str, index_column: &str) -> String {
114+
format!(
115+
"INSERT INTO \"{table_name}\" ({ID_COLUMN}, {index_column}, {DATA_COLUMN}, {EXPIRES_COLUMN}) \
116+
VALUES ($1, $2, $3, $4) \
117+
ON CONFLICT ({ID_COLUMN}) DO UPDATE SET \
118+
{DATA_COLUMN} = EXCLUDED.{DATA_COLUMN}, \
119+
{EXPIRES_COLUMN} = EXCLUDED.{EXPIRES_COLUMN}"
120+
)
121+
}
122+
123+
/// Bind the session ID
124+
pub(super) fn delete_sql(table_name: &str) -> String {
125+
format!("DELETE FROM \"{table_name}\" WHERE {ID_COLUMN} = $1")
126+
}
127+
128+
pub(super) fn expires_to_ttl(expires: OffsetDateTime) -> u32 {
129+
(expires - OffsetDateTime::now_utc())
130+
.whole_seconds()
131+
.try_into()
132+
.unwrap_or(0)
133+
}
134+
135+
/// Session cleanup task
136+
#[derive(Default)]
137+
pub(super) struct SqlxCleanupTask {
138+
interval: Option<std::time::Duration>,
139+
shutdown_tx: Mutex<Option<oneshot::Sender<u8>>>,
140+
table_name: String,
141+
}
142+
143+
impl SqlxCleanupTask {
144+
pub(super) fn new(cleanup_interval: Option<std::time::Duration>, table_name: &str) -> Self {
145+
Self {
146+
interval: cleanup_interval,
147+
shutdown_tx: Mutex::default(),
148+
table_name: table_name.to_string(),
149+
}
150+
}
151+
152+
pub(super) async fn setup<DB>(&self, pool: &sqlx::Pool<DB>) -> SessionResult<()>
153+
where
154+
DB: sqlx::Database,
155+
for<'q> <DB as sqlx::Database>::Arguments<'q>: sqlx::IntoArguments<'q, DB>,
156+
for<'c> &'c mut <DB as sqlx::Database>::Connection: sqlx::Executor<'c, Database = DB>,
157+
OffsetDateTime: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
158+
{
159+
let Some(cleanup_interval) = self.interval else {
160+
return Ok(());
161+
};
162+
163+
let (tx, mut rx) = oneshot::channel();
164+
self.shutdown_tx.lock().await.replace(tx);
165+
166+
let pool = pool.clone();
167+
let table_name = self.table_name.clone();
168+
rocket::tokio::spawn(async move {
169+
rocket::info!("Starting session cleanup monitor");
170+
let mut interval = interval(cleanup_interval);
171+
loop {
172+
rocket::tokio::select! {
173+
_ = interval.tick() => {
174+
rocket::debug!("Cleaning up expired sessions");
175+
if let Err(e) = sqlx::query(&format!(
176+
"DELETE FROM \"{table_name}\" WHERE {EXPIRES_COLUMN} < $1"
177+
))
178+
.bind(OffsetDateTime::now_utc())
179+
.execute(&pool)
180+
.await
181+
{
182+
rocket::error!("Error deleting expired sessions: {e}");
183+
}
184+
}
185+
_ = &mut rx => {
186+
rocket::info!("Session cleanup monitor shutdown");
187+
}
188+
}
189+
}
190+
});
191+
192+
Ok(())
193+
}
194+
195+
pub(super) async fn shutdown(&self) -> SessionResult<()> {
196+
if let Some(tx) = self.shutdown_tx.lock().await.take() {
197+
tx.send(0).map_err(|_| {
198+
SessionError::SetupTeardown("Failed to send shutdown signal".to_string())
199+
})?;
200+
}
201+
Ok(())
202+
}
203+
}

0 commit comments

Comments
 (0)