@@ -11,6 +11,15 @@ use crate::state::AppState;
1111use std:: sync:: Arc ;
1212use tauri:: State ;
1313
14+ fn connection_pool_key ( id : i64 , database : & Option < String > ) -> String {
15+ if let Some ( db) = database {
16+ if !db. is_empty ( ) {
17+ return format ! ( "{}:{}" , id, db) ;
18+ }
19+ }
20+ id. to_string ( )
21+ }
22+
1423pub async fn ensure_connection (
1524 state : & State < ' _ , AppState > ,
1625 id : i64 ,
@@ -23,15 +32,7 @@ pub async fn ensure_connection_with_db(
2332 id : i64 ,
2433 database : Option < String > ,
2534) -> Result < Arc < dyn DatabaseDriver > , String > {
26- let key = if let Some ( db) = & database {
27- if !db. is_empty ( ) {
28- format ! ( "{}:{}" , id, db)
29- } else {
30- id. to_string ( )
31- }
32- } else {
33- id. to_string ( )
34- } ;
35+ let key = connection_pool_key ( id, & database) ;
3536
3637 if let Some ( driver) = state. pool_manager . get_connection ( & key) . await {
3738 // Harden: Check if connection still exists in LocalDb
@@ -66,35 +67,27 @@ pub async fn ensure_connection_with_db(
6667 state. pool_manager . connect ( & key, & form) . await
6768}
6869
69- pub async fn execute_with_retry < F , Fut , T > (
70- state : & State < ' _ , AppState > ,
71- id : i64 ,
72- database : Option < String > ,
73- task : F ,
70+ async fn execute_with_retry_core < T , Ensure , EnsureFut , Remove , RemoveFut , Task , TaskFut > (
71+ mut ensure : Ensure ,
72+ mut remove : Remove ,
73+ task : Task ,
7474) -> Result < T , String >
7575where
76- F : Fn ( Arc < dyn DatabaseDriver > ) -> Fut ,
77- Fut : std:: future:: Future < Output = Result < T , String > > ,
76+ Ensure : FnMut ( ) -> EnsureFut ,
77+ EnsureFut : std:: future:: Future < Output = Result < Arc < dyn DatabaseDriver > , String > > ,
78+ Remove : FnMut ( ) -> RemoveFut ,
79+ RemoveFut : std:: future:: Future < Output = ( ) > ,
80+ Task : Fn ( Arc < dyn DatabaseDriver > ) -> TaskFut ,
81+ TaskFut : std:: future:: Future < Output = Result < T , String > > ,
7882{
79- let driver = ensure_connection_with_db ( state , id , database . clone ( ) ) . await ?;
83+ let driver = ensure ( ) . await ?;
8084 match task ( driver. clone ( ) ) . await {
8185 Ok ( res) => Ok ( res) ,
8286 Err ( e) => {
8387 if is_connection_error ( & e) {
84- // Retry once
8588 println ! ( "[Pool] Connection error detected, retrying..." ) ;
86- let key = if let Some ( db) = & database {
87- if !db. is_empty ( ) {
88- format ! ( "{}:{}" , id, db)
89- } else {
90- id. to_string ( )
91- }
92- } else {
93- id. to_string ( )
94- } ;
95-
96- state. pool_manager . remove ( & key) . await ;
97- let driver = ensure_connection_with_db ( state, id, database) . await ?;
89+ remove ( ) . await ;
90+ let driver = ensure ( ) . await ?;
9891 task ( driver) . await . map_err ( |e| {
9992 println ! ( "[Pool] Retry failed: {}" , e) ;
10093 e
@@ -107,6 +100,25 @@ where
107100 }
108101}
109102
103+ pub async fn execute_with_retry < F , Fut , T > (
104+ state : & State < ' _ , AppState > ,
105+ id : i64 ,
106+ database : Option < String > ,
107+ task : F ,
108+ ) -> Result < T , String >
109+ where
110+ F : Fn ( Arc < dyn DatabaseDriver > ) -> Fut ,
111+ Fut : std:: future:: Future < Output = Result < T , String > > ,
112+ {
113+ let key = connection_pool_key ( id, & database) ;
114+ execute_with_retry_core (
115+ || ensure_connection_with_db ( state, id, database. clone ( ) ) ,
116+ || state. pool_manager . remove ( & key) ,
117+ task,
118+ )
119+ . await
120+ }
121+
110122fn is_connection_error ( e : & str ) -> bool {
111123 let lower = e. to_lowercase ( ) ;
112124 lower. contains ( "pool closed" )
@@ -117,3 +129,173 @@ fn is_connection_error(e: &str) -> bool {
117129 || lower. contains ( "closed" )
118130 || lower. contains ( "eof" )
119131}
132+
133+ #[ cfg( test) ]
134+ mod tests {
135+ use super :: execute_with_retry_core;
136+ use crate :: db:: drivers:: DatabaseDriver ;
137+ use crate :: models:: {
138+ QueryResult , SchemaOverview , TableDataResponse , TableInfo , TableMetadata , TableStructure ,
139+ } ;
140+ use async_trait:: async_trait;
141+ use std:: sync:: Arc ;
142+ use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
143+
144+ struct MockDriver ;
145+
146+ #[ async_trait]
147+ impl DatabaseDriver for MockDriver {
148+ async fn close ( & self ) { }
149+ async fn test_connection ( & self ) -> Result < ( ) , String > {
150+ Ok ( ( ) )
151+ }
152+ async fn list_databases ( & self ) -> Result < Vec < String > , String > {
153+ Ok ( vec ! [ ] )
154+ }
155+ async fn list_tables ( & self , _schema : Option < String > ) -> Result < Vec < TableInfo > , String > {
156+ Ok ( vec ! [ ] )
157+ }
158+ async fn get_table_structure (
159+ & self ,
160+ _schema : String ,
161+ _table : String ,
162+ ) -> Result < TableStructure , String > {
163+ Err ( "Unimplemented" . into ( ) )
164+ }
165+ async fn get_table_metadata (
166+ & self ,
167+ _schema : String ,
168+ _table : String ,
169+ ) -> Result < TableMetadata , String > {
170+ Err ( "Unimplemented" . into ( ) )
171+ }
172+ async fn get_table_ddl ( & self , _schema : String , _table : String ) -> Result < String , String > {
173+ Err ( "Unimplemented" . into ( ) )
174+ }
175+ async fn get_table_data (
176+ & self ,
177+ _schema : String ,
178+ _table : String ,
179+ _page : i64 ,
180+ _limit : i64 ,
181+ _sort_column : Option < String > ,
182+ _sort_direction : Option < String > ,
183+ _filter : Option < String > ,
184+ _order_by : Option < String > ,
185+ ) -> Result < TableDataResponse , String > {
186+ Err ( "Unimplemented" . into ( ) )
187+ }
188+ async fn get_table_data_chunk (
189+ & self ,
190+ _schema : String ,
191+ _table : String ,
192+ _page : i64 ,
193+ _limit : i64 ,
194+ _sort_column : Option < String > ,
195+ _sort_direction : Option < String > ,
196+ _filter : Option < String > ,
197+ _order_by : Option < String > ,
198+ ) -> Result < TableDataResponse , String > {
199+ Err ( "Unimplemented" . into ( ) )
200+ }
201+ async fn execute_query ( & self , _sql : String ) -> Result < QueryResult , String > {
202+ Err ( "Unimplemented" . into ( ) )
203+ }
204+ async fn get_schema_overview (
205+ & self ,
206+ _schema : Option < String > ,
207+ ) -> Result < SchemaOverview , String > {
208+ Err ( "Unimplemented" . into ( ) )
209+ }
210+ }
211+
212+ #[ tokio:: test]
213+ async fn execute_with_retry_retries_once_on_connection_error_and_succeeds ( ) {
214+ let ensure_calls = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
215+ let remove_calls = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
216+ let task_calls = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
217+ let driver: Arc < dyn DatabaseDriver > = Arc :: new ( MockDriver ) ;
218+
219+ let ensure_calls_c = ensure_calls. clone ( ) ;
220+ let ensure_driver = driver. clone ( ) ;
221+ let remove_calls_c = remove_calls. clone ( ) ;
222+ let task_calls_c = task_calls. clone ( ) ;
223+
224+ let result: Result < String , String > = execute_with_retry_core (
225+ move || {
226+ let ensure_calls_c = ensure_calls_c. clone ( ) ;
227+ let ensure_driver = ensure_driver. clone ( ) ;
228+ async move {
229+ ensure_calls_c. fetch_add ( 1 , Ordering :: SeqCst ) ;
230+ Ok ( ensure_driver)
231+ }
232+ } ,
233+ move || {
234+ let remove_calls_c = remove_calls_c. clone ( ) ;
235+ async move {
236+ remove_calls_c. fetch_add ( 1 , Ordering :: SeqCst ) ;
237+ }
238+ } ,
239+ move |_driver| {
240+ let task_calls_c = task_calls_c. clone ( ) ;
241+ async move {
242+ let n = task_calls_c. fetch_add ( 1 , Ordering :: SeqCst ) ;
243+ if n == 0 {
244+ Err ( "[QUERY_ERROR] connection reset by peer" . to_string ( ) )
245+ } else {
246+ Ok ( "ok" . to_string ( ) )
247+ }
248+ }
249+ } ,
250+ )
251+ . await ;
252+
253+ assert_eq ! ( result. unwrap( ) , "ok" ) ;
254+ assert_eq ! ( task_calls. load( Ordering :: SeqCst ) , 2 ) ;
255+ assert_eq ! ( ensure_calls. load( Ordering :: SeqCst ) , 2 ) ;
256+ assert_eq ! ( remove_calls. load( Ordering :: SeqCst ) , 1 ) ;
257+ }
258+
259+ #[ tokio:: test]
260+ async fn execute_with_retry_returns_retry_error_when_second_attempt_fails ( ) {
261+ let ensure_calls = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
262+ let remove_calls = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
263+ let task_calls = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
264+ let driver: Arc < dyn DatabaseDriver > = Arc :: new ( MockDriver ) ;
265+
266+ let ensure_calls_c = ensure_calls. clone ( ) ;
267+ let ensure_driver = driver. clone ( ) ;
268+ let remove_calls_c = remove_calls. clone ( ) ;
269+ let task_calls_c = task_calls. clone ( ) ;
270+
271+ let result: Result < String , String > = execute_with_retry_core (
272+ move || {
273+ let ensure_calls_c = ensure_calls_c. clone ( ) ;
274+ let ensure_driver = ensure_driver. clone ( ) ;
275+ async move {
276+ ensure_calls_c. fetch_add ( 1 , Ordering :: SeqCst ) ;
277+ Ok ( ensure_driver)
278+ }
279+ } ,
280+ move || {
281+ let remove_calls_c = remove_calls_c. clone ( ) ;
282+ async move {
283+ remove_calls_c. fetch_add ( 1 , Ordering :: SeqCst ) ;
284+ }
285+ } ,
286+ move |_driver| {
287+ let task_calls_c = task_calls_c. clone ( ) ;
288+ async move {
289+ task_calls_c. fetch_add ( 1 , Ordering :: SeqCst ) ;
290+ Err ( "[QUERY_ERROR] pool closed" . to_string ( ) )
291+ }
292+ } ,
293+ )
294+ . await ;
295+
296+ assert_eq ! ( result. unwrap_err( ) , "[QUERY_ERROR] pool closed" ) ;
297+ assert_eq ! ( task_calls. load( Ordering :: SeqCst ) , 2 ) ;
298+ assert_eq ! ( ensure_calls. load( Ordering :: SeqCst ) , 2 ) ;
299+ assert_eq ! ( remove_calls. load( Ordering :: SeqCst ) , 1 ) ;
300+ }
301+ }
0 commit comments