Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 133 additions & 1 deletion app/src/main/kotlin/com/google/ai/sample/ApiKeyManager.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class ApiKeyManager(context: Context) {
private val PREFS_NAME = "api_key_prefs"
private val API_KEYS = "api_keys"
private val CURRENT_KEY_INDEX = "current_key_index"
private val FAILED_KEYS = "failed_keys"

private val prefs: SharedPreferences = context.getSharedPreferences(PREFS_NAME, Context.MODE_PRIVATE)

Expand Down Expand Up @@ -73,6 +74,9 @@ class ApiKeyManager(context: Context) {
setCurrentKeyIndex(0)
}

// Clear this key from failed keys if it was previously marked as failed
removeFailedKey(apiKey)

Log.d(TAG, "Added new API key, total keys: ${keys.size}")
return true
}
Expand All @@ -95,6 +99,9 @@ class ApiKeyManager(context: Context) {
setCurrentKeyIndex(0)
}

// Also remove from failed keys if present
removeFailedKey(apiKey)

Log.d(TAG, "Removed API key, remaining keys: ${keys.size}")
} else {
Log.d(TAG, "API key not found for removal")
Expand Down Expand Up @@ -128,6 +135,122 @@ class ApiKeyManager(context: Context) {
return prefs.getInt(CURRENT_KEY_INDEX, 0)
}

/**
* Mark an API key as failed (e.g., due to 503 error)
* @param apiKey The API key to mark as failed
*/
fun markKeyAsFailed(apiKey: String) {
val failedKeys = getFailedKeys().toMutableList()
if (!failedKeys.contains(apiKey)) {
failedKeys.add(apiKey)
saveFailedKeys(failedKeys)
Log.d(TAG, "Marked API key as failed: ${apiKey.take(5)}...")
}
}

/**
* Remove an API key from the failed keys list
* @param apiKey The API key to remove from failed keys
*/
fun removeFailedKey(apiKey: String) {
val failedKeys = getFailedKeys().toMutableList()
if (failedKeys.remove(apiKey)) {
saveFailedKeys(failedKeys)
Log.d(TAG, "Removed API key from failed keys: ${apiKey.take(5)}...")
}
}

/**
* Get all failed API keys
* @return List of failed API keys
*/
fun getFailedKeys(): List<String> {
val keysString = prefs.getString(FAILED_KEYS, "") ?: ""
return if (keysString.isEmpty()) {
emptyList()
} else {
keysString.split(",")
}
}

/**
* Check if an API key is marked as failed
* @param apiKey The API key to check
* @return True if the key is marked as failed, false otherwise
*/
fun isKeyFailed(apiKey: String): Boolean {
return getFailedKeys().contains(apiKey)
}

/**
* Reset all failed keys
*/
fun resetFailedKeys() {
prefs.edit().remove(FAILED_KEYS).apply()
Log.d(TAG, "Reset all failed keys")
}

/**
* Check if all API keys are marked as failed
* @return True if all keys are failed, false otherwise
*/
fun areAllKeysFailed(): Boolean {
val keys = getApiKeys()
val failedKeys = getFailedKeys()
return keys.isNotEmpty() && failedKeys.size >= keys.size
}

/**
* Get the count of available API keys
* @return The number of API keys
*/
fun getKeyCount(): Int {
return getApiKeys().size
}

/**
* Switch to the next available API key that is not marked as failed
* @return The new API key or null if no valid keys are available
*/
fun switchToNextAvailableKey(): String? {
val keys = getApiKeys()
if (keys.isEmpty()) {
Log.d(TAG, "No API keys available to switch to")
return null
}

val failedKeys = getFailedKeys()
val currentIndex = getCurrentKeyIndex()

// If all keys are failed, reset failed keys and start from the beginning
if (failedKeys.size >= keys.size) {
Log.d(TAG, "All keys are marked as failed, resetting failed keys")
resetFailedKeys()
setCurrentKeyIndex(0)
return keys[0]
}

// Find the next key that is not failed
var nextIndex = (currentIndex + 1) % keys.size
var attempts = 0

while (attempts < keys.size) {
if (!failedKeys.contains(keys[nextIndex])) {
setCurrentKeyIndex(nextIndex)
Log.d(TAG, "Switched to next available key at index $nextIndex")
return keys[nextIndex]
}
nextIndex = (nextIndex + 1) % keys.size
attempts++
}

// If we get here, all keys are failed (shouldn't happen due to earlier check)
Log.d(TAG, "Could not find a non-failed key, resetting failed keys")
resetFailedKeys()
setCurrentKeyIndex(0)
return keys[0]
}

/**
* Save the list of API keys to SharedPreferences
* @param keys The list of API keys to save
Expand All @@ -137,11 +260,20 @@ class ApiKeyManager(context: Context) {
prefs.edit().putString(API_KEYS, keysString).apply()
}

/**
* Save the list of failed API keys to SharedPreferences
* @param keys The list of failed API keys to save
*/
private fun saveFailedKeys(keys: List<String>) {
val keysString = keys.joinToString(",")
prefs.edit().putString(FAILED_KEYS, keysString).apply()
}

/**
* Clear all stored API keys
*/
fun clearAllKeys() {
prefs.edit().remove(API_KEYS).remove(CURRENT_KEY_INDEX).apply()
prefs.edit().remove(API_KEYS).remove(CURRENT_KEY_INDEX).remove(FAILED_KEYS).apply()
Log.d(TAG, "Cleared all API keys")
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.google.ai.sample

import android.content.Context
import androidx.lifecycle.ViewModel
import androidx.lifecycle.ViewModelProvider
import androidx.lifecycle.viewmodel.CreationExtras
Expand Down Expand Up @@ -67,7 +68,9 @@ val GenerativeViewModelFactory = object : ViewModelProvider.Factory {
apiKey = apiKey,
generationConfig = config
)
PhotoReasoningViewModel(generativeModel)
// Pass the ApiKeyManager to the ViewModel for key rotation
val apiKeyManager = ApiKeyManager.getInstance(application)
PhotoReasoningViewModel(generativeModel, apiKeyManager)
}

isAssignableFrom(ChatViewModel::class.java) -> {
Expand Down
Loading