diff --git a/code-rs/core/src/auth.rs b/code-rs/core/src/auth.rs index 5244fff76a4..a40e979fe2f 100644 --- a/code-rs/core/src/auth.rs +++ b/code-rs/core/src/auth.rs @@ -227,6 +227,7 @@ impl CodexAuth { if !access_token_is_still_valid(&tokens.access_token, Utc::now()) { return Err(err); } + self.record_proactive_refresh_fallback(Utc::now()); } } } @@ -282,6 +283,23 @@ impl CodexAuth { self.get_current_auth_json().and_then(|t| t.tokens.clone()) } + fn record_proactive_refresh_fallback(&self, timestamp: DateTime) { + let updated = { + let mut guard = self.auth_dot_json.lock().unwrap(); + let Some(auth_dot_json) = guard.as_mut() else { + return; + }; + auth_dot_json.last_refresh = Some(timestamp); + auth_dot_json.clone() + }; + + if !self.auth_file.as_os_str().is_empty() { + if let Err(err) = write_auth_json(&self.auth_file, &updated) { + tracing::warn!("failed to persist proactive refresh fallback cooldown: {err}"); + } + } + } + /// Consider this private to integration tests. pub fn create_dummy_chatgpt_auth_for_testing() -> Self { let auth_dot_json = AuthDotJson { @@ -363,16 +381,26 @@ fn should_proactively_refresh_auth( last_refresh: Option>, access_token: Option<&str>, ) -> bool { + let now = Utc::now(); if let Some(access_token) = access_token && let Ok(Some(expires_at)) = parse_jwt_expiration(access_token) { - return expires_at - <= Utc::now() - + chrono::Duration::minutes(CHATGPT_ACCESS_TOKEN_REFRESH_WINDOW_MINUTES); + if expires_at <= now { + return true; + } + if expires_at + <= now + chrono::Duration::minutes(CHATGPT_ACCESS_TOKEN_REFRESH_WINDOW_MINUTES) + { + return last_refresh.is_none_or(|last_refresh| { + last_refresh + < now - chrono::Duration::minutes(CHATGPT_ACCESS_TOKEN_REFRESH_WINDOW_MINUTES) + }); + } + return false; } last_refresh.is_some_and(|last_refresh| { - last_refresh < Utc::now() - chrono::Duration::days(28) + last_refresh < now - chrono::Duration::days(28) }) } @@ -534,6 +562,16 @@ pub async fn auth_for_stored_account( Ok(response) => response, Err(err) => { if access_token_is_still_valid(&tokens.access_token, Utc::now()) { + last_refresh = Some(Utc::now()); + if let Err(err) = crate::auth_accounts::upsert_chatgpt_account( + code_home, + tokens.clone(), + last_refresh.expect("last_refresh set"), + account.label.clone(), + false, + ) { + tracing::warn!("failed to persist proactive refresh fallback cooldown: {err}"); + } return Ok(CodexAuth::from_tokens_with_originator_and_mode( tokens, last_refresh, @@ -563,6 +601,16 @@ pub async fn auth_for_stored_account( } } if access_token_is_still_valid(&tokens.access_token, Utc::now()) { + last_refresh = Some(Utc::now()); + if let Err(err) = crate::auth_accounts::upsert_chatgpt_account( + code_home, + tokens.clone(), + last_refresh.expect("last_refresh set"), + account.label.clone(), + false, + ) { + tracing::warn!("failed to persist proactive refresh fallback cooldown: {err}"); + } return Ok(CodexAuth::from_tokens_with_originator_and_mode( tokens, last_refresh, @@ -1432,6 +1480,16 @@ mod tests { assert!(!should_proactively_refresh_auth(Some(stale), Some(&future_access))); assert!(should_proactively_refresh_auth(Some(fresh), Some(&expiring_access))); assert!(should_proactively_refresh_auth(Some(fresh), Some(&expired_access))); + + let just_attempted = Utc::now(); + assert!(!should_proactively_refresh_auth( + Some(just_attempted), + Some(&expiring_access) + )); + assert!(should_proactively_refresh_auth( + Some(just_attempted), + Some(&expired_access) + )); } #[test] @@ -1455,13 +1513,22 @@ mod tests { let _guard = EnvVarGuard::set(REFRESH_TOKEN_URL_OVERRIDE_ENV_VAR, server.uri()); let code_home = tempdir().unwrap(); let access_token = build_jwt(serde_json::json!({ "exp": Utc::now().timestamp() + 240 })); - let account = stored_chatgpt_account(access_token.clone()); + let account = stored_chatgpt_account( + access_token.clone(), + Some(Utc::now() - chrono::Duration::minutes(10)), + ); let auth = auth_for_stored_account(code_home.path(), &account, "test") .await .expect("valid cached token should survive proactive refresh failure"); assert_eq!(auth.get_token().await.unwrap(), access_token); + + let returned_last_refresh = auth + .get_current_auth_json() + .and_then(|auth| auth.last_refresh) + .expect("fallback should record refresh cooldown"); + assert!(returned_last_refresh > account.last_refresh.unwrap()); } #[tokio::test] @@ -1477,7 +1544,7 @@ mod tests { let tokens = token_data_for_access(access_token.clone()); let auth = CodexAuth::from_tokens_with_originator_and_mode( tokens, - Some(Utc::now()), + Some(Utc::now() - chrono::Duration::minutes(10)), "test", AuthMode::ChatGPT, ); @@ -1488,6 +1555,16 @@ mod tests { .expect("valid cached token should survive proactive refresh failure"); assert_eq!(token_data.access_token, access_token); + let requests_after_fallback = server.received_requests().await.unwrap().len(); + assert_eq!(requests_after_fallback, 4); + + let token_data_again = auth + .get_token_data() + .await + .expect("fallback should suppress immediate retry"); + + assert_eq!(token_data_again.access_token, access_token); + assert_eq!(server.received_requests().await.unwrap().len(), requests_after_fallback); } #[tokio::test] @@ -1501,7 +1578,7 @@ mod tests { let _guard = EnvVarGuard::set(REFRESH_TOKEN_URL_OVERRIDE_ENV_VAR, server.uri()); let code_home = tempdir().unwrap(); let access_token = build_jwt(serde_json::json!({ "exp": Utc::now().timestamp() - 60 })); - let account = stored_chatgpt_account(access_token); + let account = stored_chatgpt_account(access_token, Some(Utc::now())); let err = auth_for_stored_account(code_home.path(), &account, "test") .await @@ -1555,22 +1632,34 @@ mod tests { format!("{header_b64}.{payload_b64}.{signature_b64}") } - fn stored_chatgpt_account(access_token: String) -> crate::auth_accounts::StoredAccount { + fn stored_chatgpt_account( + access_token: String, + last_refresh: Option>, + ) -> crate::auth_accounts::StoredAccount { crate::auth_accounts::StoredAccount { id: "account-id".to_string(), mode: AuthMode::ChatGPT, label: None, openai_api_key: None, tokens: Some(token_data_for_access(access_token)), - last_refresh: Some(Utc::now()), + last_refresh, created_at: None, last_used_at: None, } } fn token_data_for_access(access_token: String) -> TokenData { + let raw_jwt = build_jwt(serde_json::json!({ + "email": "user@example.com", + "https://api.openai.com/auth": { + "chatgpt_plan_type": "plus" + } + })); TokenData { - id_token: IdTokenInfo::default(), + id_token: IdTokenInfo { + raw_jwt, + ..Default::default() + }, access_token, refresh_token: "refresh-token".to_string(), account_id: Some("account-id".to_string()),