diff --git a/.omp/AGENTS.md b/.omp/AGENTS.md index f301d142d..a4e199903 100644 --- a/.omp/AGENTS.md +++ b/.omp/AGENTS.md @@ -36,6 +36,7 @@ Batch fallback: NPU → GPU (Adreno 740) → CPU. E-4B and E-2B support thinking | `:core:voice` | STT, TTS, voice mode, push-to-talk | | `:core:memory` | sqlite-vec JNI, EmbeddingGemma, RAG pipeline | | `:core:wasm` | Chicory Wasm host, bridge functions, resource limits | +| `:core:model-availability` | ModelAvailabilityState, StateBadge, ModelCard, GatedModelStatusRepo | | `:core:ui` | Shared Compose components, Material 3 theme | | `:core:skills` | SkillInterface, SkillRegistry, JSON schema gen | | `:feature:chat` | Chat screen, conversation list, ChatViewModel | @@ -69,6 +70,29 @@ Batch fallback: NPU → GPU (Adreno 740) → CPU. E-4B and E-2B support thinking **Workflow:** Analyse → dispatch (android-developer / llm-engineer) → parallel test-writer + spec-writer → PR with `Closes #N` → parallel code-reviewer + CI → push fixes → owner tests via ADB → owner merges. +### Subagent code changes — recovery pattern + +Task agents run in **ephemeral, isolated worktrees** that are cleaned up on completion. +Their file writes never reach your worktree. To extract their changes, use one of: + +**Option A — diff output (preferred):** Add this to the end of every code-changing assignment: +``` +LAST STEP — output your changes as a patch: +1. Run `git diff` (do NOT omit this step). +2. Copy the ENTIRE diff output into your final message verbatim, + wrapped in a ```diff code block. +Do NOT summarise your changes — I need the raw diff to `git apply`. +``` +Then apply in your worktree: pipe the diff block into `git apply`. + +**Option B — raw file content:** Instruct the agent to `cat` each modified file. +The artifact output will contain the full content; copy it with `write`. + +**Option C — GitHub push:** For larger changes, tell the agent to `git push` its branch, +then `git fetch` + `git merge` from your worktree. + +**Never** assume a `task` agent's file modifications are visible in your worktree. + ## Branch isolation **Do not modify the main checkout directly.** Every session that touches code must use a dedicated worktree: @@ -164,9 +188,8 @@ Write to memory (`memory://root/skills//SKILL.md`) after discovering: - Build/debug quirks (tool flags, adb incantations, test setup) - Architectural invariants that caused a bug (e.g. "gemma4InitMutex required") - Tool invocation patterns that save tokens (rtk, context-mode) - Consult memory via `memory://root` before starting work in an unfamiliar module. -Existing entries: model_loading_order, test_patterns, branch_isolation, rtk_token_saver, adreno_buffer_workaround, github_api_pagination, meal_planner_state, documentation_sync. +Existing entries: model_loading_order, test_patterns, branch_isolation, rtk_token_saver, adreno_buffer_workaround, github_api_pagination, meal_planner_state, documentation_sync, model_availability_state. ## On-demand reference docs diff --git a/app/src/main/java/com/kernel/ai/navigation/KernelNavHost.kt b/app/src/main/java/com/kernel/ai/navigation/KernelNavHost.kt index d4e3ea057..34e4c8a3c 100644 --- a/app/src/main/java/com/kernel/ai/navigation/KernelNavHost.kt +++ b/app/src/main/java/com/kernel/ai/navigation/KernelNavHost.kt @@ -479,6 +479,9 @@ fun KernelNavHost( onNavigateToSettings = { navController.navigate(ROUTE_SETTINGS) }, + onNavigateToModelManagement = { + navController.navigate(ROUTE_MODEL_MANAGEMENT) + }, ) } @@ -499,6 +502,9 @@ fun KernelNavHost( onNavigateToSettings = { navController.navigate(ROUTE_SETTINGS) }, + onNavigateToModelManagement = { + navController.navigate(ROUTE_MODEL_MANAGEMENT) + }, ) } @@ -599,6 +605,9 @@ fun KernelNavHost( composable(ROUTE_VOICE) { VoiceScreen( onBack = { navController.popBackStack() }, + onNavigateToModelManagement = { + navController.navigate(ROUTE_MODEL_MANAGEMENT) + }, ) } diff --git a/core/inference/src/main/java/com/kernel/ai/core/inference/download/DownloadSource.kt b/core/inference/src/main/java/com/kernel/ai/core/inference/download/DownloadSource.kt new file mode 100644 index 000000000..553740ce5 --- /dev/null +++ b/core/inference/src/main/java/com/kernel/ai/core/inference/download/DownloadSource.kt @@ -0,0 +1,14 @@ +package com.kernel.ai.core.inference.download + +/** + * Source of a download request — used to distinguish auto-queued downloads from + * user-initiated ones. + * + * - [AUTO_QUEUED]: Started by the system on startup (required models, tier-preferred models, + * and co-dependent files like SentencePiece). These cannot be cancelled via the UI. + * - [USER_INITIATED]: Started by explicit user action. Can be cancelled. + */ +enum class DownloadSource { + AUTO_QUEUED, + USER_INITIATED, +} diff --git a/core/inference/src/main/java/com/kernel/ai/core/inference/download/KernelModel.kt b/core/inference/src/main/java/com/kernel/ai/core/inference/download/KernelModel.kt index fe37fb96c..e3040269d 100644 --- a/core/inference/src/main/java/com/kernel/ai/core/inference/download/KernelModel.kt +++ b/core/inference/src/main/java/com/kernel/ai/core/inference/download/KernelModel.kt @@ -45,6 +45,13 @@ enum class KernelModel( * Defaults to `true` so existing entries are unaffected. */ val showInModelManagement: Boolean = true, + /** + * If `true`, this model has been superseded and is hidden from the Model Management + * screen and the preferred-model picker. The existing download is not deleted — the + * user must manually delete it through the storage settings. + * Defaults to `false` so existing entries are unaffected. + */ + val isDeprecated: Boolean = false, ) { GEMMA_4_E2B( displayName = "Gemma 4 E-2B", @@ -88,6 +95,7 @@ enum class KernelModel( preferredForTier = null, isGated = true, licenceUrl = "https://huggingface.co/litert-community/embeddinggemma-300m", + isDeprecated = true, ), EMBEDDING_GEMMA_SP_MODEL( diff --git a/core/inference/src/main/java/com/kernel/ai/core/inference/download/ModelDownloadManager.kt b/core/inference/src/main/java/com/kernel/ai/core/inference/download/ModelDownloadManager.kt index 213404884..7be7121f5 100644 --- a/core/inference/src/main/java/com/kernel/ai/core/inference/download/ModelDownloadManager.kt +++ b/core/inference/src/main/java/com/kernel/ai/core/inference/download/ModelDownloadManager.kt @@ -28,6 +28,8 @@ import kotlinx.coroutines.withContext import javax.inject.Inject import javax.inject.Singleton +import kotlinx.coroutines.flow.update + private const val TAG = "ModelDownloadManager" /** @@ -70,6 +72,14 @@ class ModelDownloadManager @Inject constructor( val downloadStates: StateFlow> = _downloadStates.asStateFlow() + /** + * Tracks the [DownloadSource] for each model. Populated when [startDownload] is called. + * Used by the UI layer to determine whether cancel is allowed. + */ + private val _downloadSources: MutableStateFlow> = + MutableStateFlow(emptyMap()) + val downloadSources: StateFlow> = _downloadSources.asStateFlow() + val deviceTier: HardwareTier get() = hardwareProfileDetector.profile.tier init { @@ -87,7 +97,7 @@ class ModelDownloadManager @Inject constructor( } .forEach { model -> Log.i(TAG, "Auto-queuing required model: ${model.displayName}") - startDownload(model) + startDownload(model, source = DownloadSource.AUTO_QUEUED) } // Auto-queue tier-specific optional models (e.g. E-4B on FLAGSHIP) // NOTE: tier is already declared above @@ -98,7 +108,7 @@ class ModelDownloadManager @Inject constructor( } .forEach { model -> Log.i(TAG, "Auto-queuing ${model.displayName} for tier ${tier.name}") - startDownload(model) + startDownload(model, source = DownloadSource.AUTO_QUEUED) } // Auto-trigger gated required models when user signs in scope.launch { @@ -108,7 +118,7 @@ class ModelDownloadManager @Inject constructor( KernelModel.entries .filter { m -> m.isGated && m.isRequired } .filter { m -> _downloadStates.value[m] is DownloadState.NotDownloaded } - .forEach { m -> startDownload(m) } + .forEach { m -> startDownload(m, source = DownloadSource.AUTO_QUEUED) } } } } @@ -127,13 +137,16 @@ class ModelDownloadManager @Inject constructor( * - Otherwise → [ExistingWorkPolicy.REPLACE] to unstick any stale ENQUEUED job that * Samsung's battery manager prevented from dispatching, and to restart FAILED jobs. */ - fun startDownload(model: KernelModel, force: Boolean = false) { + fun startDownload(model: KernelModel, force: Boolean = false, source: DownloadSource = DownloadSource.USER_INITIATED) { if (model.isBundled) return // bundled assets are always available; nothing to download if (!force && model.isDownloaded(context)) { updateState(model, DownloadState.Downloaded(model.localFile(context).absolutePath)) return } + + // Track the download source for UI layer + _downloadSources.update { it.toMutableMap().apply { put(model, source) } } Log.i(TAG, "Enqueuing download for ${model.displayName}") // updateState moved inside coroutine — don't reset progress to 0 if KEEP is chosen @@ -193,7 +206,16 @@ class ModelDownloadManager @Inject constructor( /** Cancel an in-progress download. The partial `.tmp` file is preserved for resumption. */ fun cancelDownload(model: KernelModel) { + // Only user-initiated downloads can be cancelled — auto-queued models are needed + // for the app to function. Check the stored source rather than model.isRequired + // because some required models may be user-initiated (e.g. E2B on FLAGSHIP). + val source = _downloadSources.value[model] ?: DownloadSource.USER_INITIATED + if (source == DownloadSource.AUTO_QUEUED) { + Log.w(TAG, "Refusing to cancel auto-queued download: ${model.displayName}") + return + } workManager.cancelUniqueWork(model.workerTag) + _downloadSources.update { it.toMutableMap().apply { remove(model) } } updateState(model, DownloadState.NotDownloaded) Log.i(TAG, "Cancelled download for ${model.displayName}") } @@ -205,16 +227,16 @@ class ModelDownloadManager @Inject constructor( return if (model.isDownloaded(context)) model.localFile(context).absolutePath else null } - /** - * Re-checks the filesystem for [model] and updates [downloadStates] accordingly. - * Call this after manually deleting a model file so the UI reflects [DownloadState.NotDownloaded]. - */ fun refreshState(model: KernelModel) { val newState = if (model.isDownloaded(context)) { DownloadState.Downloaded(model.localFile(context).absolutePath) } else { DownloadState.NotDownloaded } + // Clear stale source tracking since the model is no longer actively downloading + if (newState !is DownloadState.Downloading) { + _downloadSources.update { it.toMutableMap().apply { remove(model) } } + } updateState(model, newState) Log.i(TAG, "Refreshed state for ${model.displayName}: $newState") } @@ -267,7 +289,7 @@ class ModelDownloadManager @Inject constructor( // ------------------------------------------------------------------------- private fun updateState(model: KernelModel, state: DownloadState) { - _downloadStates.value = _downloadStates.value.toMutableMap().apply { put(model, state) } + _downloadStates.update { it.toMutableMap().apply { put(model, state) } } } // Issue 3 fix: guard against launching duplicate observeWorkInfo coroutines diff --git a/core/inference/src/test/java/com/kernel/ai/core/inference/download/KernelModelTest.kt b/core/inference/src/test/java/com/kernel/ai/core/inference/download/KernelModelTest.kt new file mode 100644 index 000000000..0445a046f --- /dev/null +++ b/core/inference/src/test/java/com/kernel/ai/core/inference/download/KernelModelTest.kt @@ -0,0 +1,26 @@ +package com.kernel.ai.core.inference.download + +import org.junit.jupiter.api.Assertions.assertFalse +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Test + +class KernelModelTest { + + @Test + fun `isDeprecated defaults to false for all models`() { + KernelModel.entries.forEach { model -> + // Only SM8550 is explicitly deprecated; all others default to false + if (model == KernelModel.EMBEDDING_GEMMA_300M_SM8550) { + assertTrue(model.isDeprecated, "Expected ${model.name} to be deprecated") + } else { + assertFalse(model.isDeprecated, "Expected ${model.name} isDeprecated to be false") + } + } + } + + @Test + fun `deprecated model is excluded from preferredForTier matches`() { + // SM8550 is deprecated — it should not match any tier preference logic + assertTrue(KernelModel.EMBEDDING_GEMMA_300M_SM8550.isDeprecated) + } +} diff --git a/core/model-availability/build.gradle.kts b/core/model-availability/build.gradle.kts new file mode 100644 index 000000000..f7f691428 --- /dev/null +++ b/core/model-availability/build.gradle.kts @@ -0,0 +1,66 @@ +plugins { + alias(libs.plugins.android.library) + alias(libs.plugins.kotlin.android) + alias(libs.plugins.kotlin.compose) + alias(libs.plugins.ksp) + alias(libs.plugins.hilt) +} + +android { + namespace = "com.kernel.ai.core.model.availability" + compileSdk = libs.versions.compileSdk.get().toInt() + + defaultConfig { + minSdk = libs.versions.minSdk.get().toInt() + testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner" + } + + buildFeatures { + compose = true + } + + compileOptions { + sourceCompatibility = JavaVersion.VERSION_17 + targetCompatibility = JavaVersion.VERSION_17 + } + + kotlinOptions { + jvmTarget = "17" + } + + testOptions { + unitTests.all { it.useJUnitPlatform() } + } +} + +dependencies { + implementation(project(":core:ui")) + implementation(project(":core:inference")) + + // Compose + implementation(platform(libs.compose.bom)) + implementation(libs.compose.ui) + implementation(libs.compose.material3) + implementation(libs.compose.material.icons) + implementation(libs.compose.foundation) + implementation(libs.compose.ui.tooling.preview) + implementation(libs.lifecycle.viewmodel.compose) + implementation(libs.lifecycle.runtime.compose) + + ksp(libs.hilt.compiler) + + // Hilt + implementation(libs.hilt.android) + + // DataStore + implementation(libs.datastore.preferences) + + debugImplementation(libs.compose.ui.tooling) + + compileOnly(libs.compose.ui.test.manifest) + + testImplementation(libs.junit.jupiter) + testImplementation(libs.mockk) + testImplementation(libs.coroutines.test) + testImplementation(libs.compose.ui.test.junit4) +} diff --git a/core/model-availability/src/main/java/com/kernel/ai/core/model/availability/AvailabilitySummary.kt b/core/model-availability/src/main/java/com/kernel/ai/core/model/availability/AvailabilitySummary.kt new file mode 100644 index 000000000..f38a62746 --- /dev/null +++ b/core/model-availability/src/main/java/com/kernel/ai/core/model/availability/AvailabilitySummary.kt @@ -0,0 +1,63 @@ +package com.kernel.ai.core.model.availability + +import com.kernel.ai.core.inference.download.DownloadSource +import com.kernel.ai.core.inference.download.DownloadState +import com.kernel.ai.core.inference.download.KernelModel + +/** + * Summary counts of models in each [ModelAvailabilityState]. + */ +data class AvailabilitySummary( + val total: Int, + val ready: Int = 0, + val preparing: Int = 0, + val actionRequired: Int = 0, + val unavailable: Int = 0, +) { + val displaySummary: String get() { + return "$ready of $total models available" + } +} + +/** + * Computes an [AvailabilitySummary] from a list of models and their download states. + */ +fun computeAvailabilitySummary( + models: List, + downloadStates: Map, + hfAuth: Boolean, + downloadSources: Map = emptyMap(), + gatedStatuses: Map = emptyMap(), +): AvailabilitySummary { + var ready = 0 + var preparing = 0 + var actionRequired = 0 + var unavailable = 0 + + for (model in models) { + val state = downloadStates[model] ?: DownloadState.NotDownloaded + val source = downloadSources[model] ?: DownloadSource.USER_INITIATED + val gatedStatus = gatedStatuses[model] ?: GatedModelStatus.NONE + val availability = state.toAvailability( + model = model, + hfAuth = hfAuth, + source = source, + gated = gatedStatus, + ) + when (availability) { + is ModelAvailabilityState.Ready -> ready++ + is ModelAvailabilityState.Preparing -> preparing++ + is ModelAvailabilityState.ActionRequired -> actionRequired++ + is ModelAvailabilityState.Unavailable -> unavailable++ + ModelAvailabilityState.NotDisplayed -> {} // NotDisplayed = no badge shown + } + } + + return AvailabilitySummary( + total = models.size, + ready = ready, + preparing = preparing, + actionRequired = actionRequired, + unavailable = unavailable, + ) +} diff --git a/core/model-availability/src/main/java/com/kernel/ai/core/model/availability/DownloadStateMapper.kt b/core/model-availability/src/main/java/com/kernel/ai/core/model/availability/DownloadStateMapper.kt new file mode 100644 index 000000000..1eecb2f3b --- /dev/null +++ b/core/model-availability/src/main/java/com/kernel/ai/core/model/availability/DownloadStateMapper.kt @@ -0,0 +1,64 @@ +package com.kernel.ai.core.model.availability + +import com.kernel.ai.core.inference.download.DownloadSource +import com.kernel.ai.core.inference.download.DownloadState +import com.kernel.ai.core.inference.download.KernelModel + +/** + * Maps the core [DownloadState] to the UI-layer [ModelAvailabilityState]. + * + * Truth table (see docs/model-availability-ux-patterns.md): + * + * | DownloadState | isBundled | isGated | hfAuth | source | gated | → Result | + * |----------------------|-----------|---------|--------|--------------|-----------------|-----------------------------------| + * | Downloaded(*) | any | any | any | any | any | Ready | + * | NotDownloaded | true | any | any | any | any | Ready | + * | Downloading(p) | any | any | any | any | any | Preparing(p, source == AUTO_QUEUED)| + * | NotDownloaded | false | true | false | any | any | ActionRequired(SignInRequired) | + * | NotDownloaded | false | true | true | any | APPROVAL_PENDING | ActionRequired(ApprovalPending) | + * | NotDownloaded | false | true | true | any | ACCESS_DENIED | Unavailable(AccessDenied) | + * | NotDownloaded | false | false | any | AUTO_QUEUED | any | Preparing(0f, isAutoQueued = true)| + * | NotDownloaded | false | false | any | USER_INITIATED| any | (no badge — primary action only) | + * | Error(licence=T) | any | any | any | any | any | ActionRequired(LicenseRequired) | + * | Error(message) | any | any | any | any | any | ActionRequired(DownloadFailed(msg))| + */ +fun DownloadState.toAvailability( + model: KernelModel, + hfAuth: Boolean, + source: DownloadSource = DownloadSource.USER_INITIATED, + gated: GatedModelStatus = GatedModelStatus.NONE, +): ModelAvailabilityState { + return when (this) { + is DownloadState.Downloaded -> ModelAvailabilityState.Ready + is DownloadState.Downloading -> ModelAvailabilityState.Preparing( + progress = progress, + isAutoQueued = source == DownloadSource.AUTO_QUEUED, + ) + is DownloadState.NotDownloaded -> { + if (model.isBundled) return ModelAvailabilityState.Ready + if (model.isGated) { + if (!hfAuth) return ModelAvailabilityState.ActionRequired(ActionReason.SignInRequired) + return when (gated) { + GatedModelStatus.APPROVAL_PENDING -> ModelAvailabilityState.ActionRequired(ActionReason.ApprovalPending) + GatedModelStatus.ACCESS_DENIED -> ModelAvailabilityState.Unavailable(UnavailableReason.AccessDenied) + else -> ModelAvailabilityState.NotDisplayed + } + } + // Ungated model — source determines display + when (source) { + DownloadSource.AUTO_QUEUED -> ModelAvailabilityState.Preparing( + progress = 0f, + isAutoQueued = true, + ) + DownloadSource.USER_INITIATED -> ModelAvailabilityState.NotDisplayed + } + } + is DownloadState.Error -> { + if (licenceRequired) { + ModelAvailabilityState.ActionRequired(ActionReason.LicenseRequired) + } else { + ModelAvailabilityState.ActionRequired(ActionReason.DownloadFailed(message)) + } + } +} +} diff --git a/core/model-availability/src/main/java/com/kernel/ai/core/model/availability/GatedModelStatus.kt b/core/model-availability/src/main/java/com/kernel/ai/core/model/availability/GatedModelStatus.kt new file mode 100644 index 000000000..505fb8376 --- /dev/null +++ b/core/model-availability/src/main/java/com/kernel/ai/core/model/availability/GatedModelStatus.kt @@ -0,0 +1,19 @@ +package com.kernel.ai.core.model.availability + +/** + * Represents the gated-model access status for a [KernelModel] that is gated on + * HuggingFace. Persisted in DataStore; status decisions are made server-side + * by the HuggingFace moderation system. + * + * - [NONE]: No status known — user can attempt to download. The backend will + * report the result (approval pending / denied / success). + * - [APPROVAL_PENDING]: User has requested access, waiting for HF moderation. + * - [APPROVED]: Access granted — download can proceed. + * - [ACCESS_DENIED]: HF moderation rejected the access request. + */ +enum class GatedModelStatus { + NONE, + APPROVAL_PENDING, + APPROVED, + ACCESS_DENIED, +} diff --git a/core/model-availability/src/main/java/com/kernel/ai/core/model/availability/GatedModelStatusRepository.kt b/core/model-availability/src/main/java/com/kernel/ai/core/model/availability/GatedModelStatusRepository.kt new file mode 100644 index 000000000..2dccb116c --- /dev/null +++ b/core/model-availability/src/main/java/com/kernel/ai/core/model/availability/GatedModelStatusRepository.kt @@ -0,0 +1,72 @@ +package com.kernel.ai.core.model.availability + +import android.content.Context +import android.util.Log +import androidx.datastore.core.DataStore +import androidx.datastore.preferences.core.Preferences +import androidx.datastore.preferences.core.edit +import androidx.datastore.preferences.core.stringPreferencesKey +import androidx.datastore.preferences.preferencesDataStore +import com.kernel.ai.core.inference.download.KernelModel +import dagger.hilt.android.qualifiers.ApplicationContext +import androidx.datastore.preferences.core.emptyPreferences +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.catch +import kotlinx.coroutines.flow.first +import kotlinx.coroutines.flow.map +import java.io.IOException +import javax.inject.Inject +import javax.inject.Singleton + +private const val TAG = "GatedModelStatusRepo" + +private val Context.gatedStatusDataStore: DataStore by + preferencesDataStore(name = "gated_model_status") + +/** + * DataStore-backed repository for per-model [GatedModelStatus]. + * + * This is a lightweight scaffolding until the real HuggingFace moderation + * webhook is implemented. The debug toggle in Settings → About provides + * manual QA control over each model's status. + */ +@Singleton +class GatedModelStatusRepository @Inject constructor( + @ApplicationContext private val context: Context, +) { + /** Observe the status for a specific model. */ + fun get(model: KernelModel): Flow = + context.gatedStatusDataStore.data + .catch { e -> + if (e is IOException) { + Log.w(TAG, "DataStore read error for ${model.modelId}", e) + emit(emptyPreferences()) + } else throw e + } + .map { prefs -> + val raw = prefs[key(model)] ?: return@map GatedModelStatus.NONE + try { + GatedModelStatus.valueOf(raw) + } catch (_: IllegalArgumentException) { + GatedModelStatus.NONE + } + } + + /** Set the status for a specific model. */ + suspend fun set(model: KernelModel, status: GatedModelStatus) { + context.gatedStatusDataStore.edit { prefs -> + if (status == GatedModelStatus.NONE) { + prefs.remove(key(model)) + } else { + prefs[key(model)] = status.name + } + } + Log.i(TAG, "Set ${model.modelId} → $status") + } + + /** Snapshot read (non-flow). Useful for one-shot checks. */ + suspend fun getSnapshot(model: KernelModel): GatedModelStatus = + get(model).first() + + private fun key(model: KernelModel) = stringPreferencesKey("gated_${model.modelId}") +} diff --git a/core/model-availability/src/main/java/com/kernel/ai/core/model/availability/ModelAvailabilityState.kt b/core/model-availability/src/main/java/com/kernel/ai/core/model/availability/ModelAvailabilityState.kt new file mode 100644 index 000000000..aa0d1ca16 --- /dev/null +++ b/core/model-availability/src/main/java/com/kernel/ai/core/model/availability/ModelAvailabilityState.kt @@ -0,0 +1,42 @@ +package com.kernel.ai.core.model.availability + +/** + * Canonical 4-state model for model availability in the UI layer. + * + * Mapped from [com.kernel.ai.core.inference.download.DownloadState] via + * [DownloadStateMapper.toAvailability]. + * + * States: + * - [Ready]: Model is on disk and ready for inference. + * - [Preparing]: Download is in progress or auto-queued. + * - [ActionRequired]: User must take an action (sign in, accept licence, etc.). + * - [Unavailable]: Model cannot be used on this device or at this time. + */ +sealed class ModelAvailabilityState { + data object Ready : ModelAvailabilityState() + data class Preparing( + val progress: Float = 0f, + val isAutoQueued: Boolean = false, + ) : ModelAvailabilityState() + data class ActionRequired(val reason: ActionReason) : ModelAvailabilityState() + data class Unavailable(val reason: UnavailableReason) : ModelAvailabilityState() + /** Internal sentinel — the mapper returns this when no badge should be shown. */ + internal data object NotDisplayed : ModelAvailabilityState() +} + +sealed class ActionReason { + data object SignInRequired : ActionReason() + data object LicenseRequired : ActionReason() + data class AccessApprovalRequired(val providerName: String) : ActionReason() + data object ApprovalPending : ActionReason() + data object InsufficientStorage : ActionReason() + data class DownloadFailed(val message: String) : ActionReason() +} + +sealed class UnavailableReason { + data object AccessDenied : UnavailableReason() + data object ProviderUnavailable : UnavailableReason() + data object ModelRemoved : UnavailableReason() + data class UnsupportedDevice(val message: String) : UnavailableReason() + data object NotBundled : UnavailableReason() +} diff --git a/core/model-availability/src/main/java/com/kernel/ai/core/model/availability/ModelAvailabilityStrings.kt b/core/model-availability/src/main/java/com/kernel/ai/core/model/availability/ModelAvailabilityStrings.kt new file mode 100644 index 000000000..0ce517065 --- /dev/null +++ b/core/model-availability/src/main/java/com/kernel/ai/core/model/availability/ModelAvailabilityStrings.kt @@ -0,0 +1,75 @@ +package com.kernel.ai.core.model.availability + +/** + * User-facing label and supporting text for each [ModelAvailabilityState]. + * + * These are plain data values (no Compose dependency). The UI layer reads these + * to populate text elements — avoids leaking enum names into user-facing copy. + */ +data class AvailabilityStrings( + val label: String, + val supportingText: String? = null, +) + +fun ModelAvailabilityState.toStrings(): AvailabilityStrings = when (this) { + is ModelAvailabilityState.Ready -> AvailabilityStrings( + label = "Ready", + supportingText = null, + ) + is ModelAvailabilityState.Preparing -> AvailabilityStrings( + label = if (isAutoQueued) "Waiting" else "Downloading", + supportingText = if (isAutoQueued) "Starting soon…" else null, + ) + is ModelAvailabilityState.ActionRequired -> when (reason) { + is ActionReason.SignInRequired -> AvailabilityStrings( + label = "Sign in required", + supportingText = "Sign in to HuggingFace to download this model", + ) + is ActionReason.LicenseRequired -> AvailabilityStrings( + label = "License required", + supportingText = "Accept the model license on HuggingFace", + ) + is ActionReason.ApprovalPending -> AvailabilityStrings( + label = "Approval pending", + supportingText = "Waiting for HuggingFace moderation", + ) + is ActionReason.AccessApprovalRequired -> AvailabilityStrings( + label = "Access request required", + supportingText = "Request access on ${reason.providerName}", + ) + is ActionReason.InsufficientStorage -> AvailabilityStrings( + label = "Insufficient storage", + supportingText = "Free up space to download this model", + ) + is ActionReason.DownloadFailed -> AvailabilityStrings( + label = "Download failed", + supportingText = reason.message, + ) + } + is ModelAvailabilityState.Unavailable -> when (reason) { + is UnavailableReason.AccessDenied -> AvailabilityStrings( + label = "Access denied", + supportingText = "Your access request was denied by the provider", + ) + is UnavailableReason.ProviderUnavailable -> AvailabilityStrings( + label = "Provider unavailable", + supportingText = "The model provider is temporarily unavailable", + ) + is UnavailableReason.ModelRemoved -> AvailabilityStrings( + label = "Model removed", + supportingText = "This model has been removed from the provider", + ) + is UnavailableReason.UnsupportedDevice -> AvailabilityStrings( + label = "Unsupported device", + supportingText = reason.message, + ) + is UnavailableReason.NotBundled -> AvailabilityStrings( + label = "Not available", + supportingText = "This model is not bundled with the app", + ) + } + ModelAvailabilityState.NotDisplayed -> AvailabilityStrings( + label = "", + supportingText = null, + ) +} diff --git a/core/model-availability/src/main/java/com/kernel/ai/core/model/availability/ModelCard.kt b/core/model-availability/src/main/java/com/kernel/ai/core/model/availability/ModelCard.kt new file mode 100644 index 000000000..9c5c92c2d --- /dev/null +++ b/core/model-availability/src/main/java/com/kernel/ai/core/model/availability/ModelCard.kt @@ -0,0 +1,263 @@ +package com.kernel.ai.core.model.availability + +import androidx.compose.foundation.layout.Arrangement +import androidx.compose.foundation.layout.Column +import androidx.compose.animation.core.animateFloatAsState +import androidx.compose.foundation.layout.Row +import androidx.compose.foundation.layout.Spacer +import androidx.compose.foundation.layout.fillMaxWidth +import androidx.compose.foundation.layout.height +import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.layout.size +import androidx.compose.foundation.layout.width +import androidx.compose.material.icons.Icons +import androidx.compose.material3.LinearProgressIndicator +import androidx.compose.material.icons.filled.Lock +import androidx.compose.material3.Button +import androidx.compose.material3.Card +import androidx.compose.material3.CardDefaults +import androidx.compose.material3.Icon +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.OutlinedButton +import androidx.compose.material3.Text +import androidx.compose.runtime.Composable +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.runtime.getValue +import androidx.compose.ui.text.style.TextOverflow +import androidx.compose.ui.unit.dp + +/** + * Full-size model card used in Model Management and Settings. + * + * Shows: + * - [stateBadge] at top-right + * - Model name and description + * - Optional lock icon for gated models + * @param onPrimaryAction Click handler for the primary action button. Null = no button. + * @param primaryActionLabel Label for the primary action button. Null = auto from state. + * @param onSecondaryAction Click handler for a secondary action (e.g. Delete) shown beside + * the primary action when the model is downloaded. Null = no secondary button. + * @param secondaryActionLabel Label for the secondary action button. Ignored when + * [onSecondaryAction] is null. + * @param modifier Modifier for the card. + */ +@Composable +fun ModelCard( + title: String, + description: String?, + state: ModelAvailabilityState, + showLock: Boolean = false, + onPrimaryAction: (() -> Unit)? = null, + primaryActionLabel: String? = null, + onSecondaryAction: (() -> Unit)? = null, + secondaryActionLabel: String? = null, + modifier: Modifier = Modifier, +) { + val actionLabel = primaryActionLabel ?: defaultActionLabel(state) + + Card( + modifier = modifier.fillMaxWidth(), + colors = CardDefaults.cardColors( + containerColor = MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.3f), + ), + ) { + Column(modifier = Modifier.padding(16.dp)) { + Row( + modifier = Modifier.fillMaxWidth(), + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.SpaceBetween, + ) { + Row( + verticalAlignment = Alignment.CenterVertically, + modifier = Modifier.weight(1f), + ) { + if (showLock) { + Icon( + imageVector = Icons.Default.Lock, + contentDescription = "Gated model", + modifier = Modifier.size(16.dp), + tint = MaterialTheme.colorScheme.onSurfaceVariant, + ) + Spacer(Modifier.width(6.dp)) + } + Text( + text = title, + style = MaterialTheme.typography.titleMedium, + maxLines = 1, + overflow = TextOverflow.Ellipsis, + ) + } + Spacer(Modifier.width(8.dp)) + StateBadge(state = state) + } + + if (!description.isNullOrBlank()) { + Spacer(Modifier.height(4.dp)) + Text( + text = description, + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant, + maxLines = 2, + overflow = TextOverflow.Ellipsis, + ) + } + + if (state is ModelAvailabilityState.Preparing) { + Spacer(Modifier.height(8.dp)) + val animatedProgress by animateFloatAsState( + targetValue = state.progress.coerceIn(0f, 1f), + label = "downloadProgress", + ) + LinearProgressIndicator( + progress = { animatedProgress }, + modifier = Modifier.fillMaxWidth(), + ) + } + + if (onPrimaryAction != null && actionLabel != null) { + Spacer(Modifier.height(12.dp)) + when (state) { + is ModelAvailabilityState.Preparing -> { + // Auto-queued: no action button; User-initiated: show cancel + if (!state.isAutoQueued) { + Button( + onClick = { onPrimaryAction?.invoke() }, + modifier = Modifier.fillMaxWidth(), + ) { + Text(actionLabel ?: "Cancel") + } + } + } + is ModelAvailabilityState.Unavailable, + ModelAvailabilityState.NotDisplayed -> { + // Full-width outlined button for unavailable/not-displayed + OutlinedButton( + onClick = onPrimaryAction, + modifier = Modifier.fillMaxWidth(), + ) { + Text(actionLabel) + } + } + is ModelAvailabilityState.Ready -> { + if (onSecondaryAction != null && secondaryActionLabel != null) { + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.spacedBy(8.dp), + ) { + Button( + onClick = onPrimaryAction, + modifier = Modifier.weight(1f), + ) { + Text(actionLabel) + } + OutlinedButton(onClick = onSecondaryAction) { + Text(secondaryActionLabel) + } + } + } else { + Button( + onClick = onPrimaryAction, + modifier = Modifier.fillMaxWidth(), + ) { + Text(actionLabel) + } + } + } + is ModelAvailabilityState.ActionRequired -> { + Button( + onClick = onPrimaryAction, + modifier = Modifier.fillMaxWidth(), + ) { + Text(actionLabel) + } + } + } + } + } + } +} + +/** + * Compact variant used in VoiceScreen and Chat onboarding. + * No action button — just the name, optional description, and state badge. + */ +@Composable +fun ModelCardCompact( + title: String, + description: String?, + state: ModelAvailabilityState, + showLock: Boolean = false, + modifier: Modifier = Modifier, +) { + Column(modifier = modifier) { + Row( + modifier = Modifier + .fillMaxWidth() + .padding(vertical = 8.dp, horizontal = 16.dp), + verticalAlignment = Alignment.CenterVertically, + ) { + Column(modifier = Modifier.weight(1f)) { + Row(verticalAlignment = Alignment.CenterVertically) { + if (showLock) { + Icon( + imageVector = Icons.Default.Lock, + contentDescription = "Gated model", + modifier = Modifier.size(14.dp), + tint = MaterialTheme.colorScheme.onSurfaceVariant, + ) + Spacer(Modifier.width(4.dp)) + } + Text( + text = title, + style = MaterialTheme.typography.bodyMedium, + maxLines = 1, + overflow = TextOverflow.Ellipsis, + ) + } + if (!description.isNullOrBlank()) { + Text( + text = description, + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant, + maxLines = 1, + overflow = TextOverflow.Ellipsis, + ) + } + } + Spacer(Modifier.width(8.dp)) + StateBadge(state = state) + } + if (state is ModelAvailabilityState.Preparing) { + val animatedProgress by animateFloatAsState( + targetValue = state.progress.coerceIn(0f, 1f), + label = "downloadProgress", + ) + LinearProgressIndicator( + progress = { animatedProgress }, + modifier = Modifier + .fillMaxWidth() + .padding(horizontal = 16.dp), + ) + } + } +} + +/** + * Returns a default action label for a given availability state. + * Used when [ModelCard] is constructed without [primaryActionLabel]. + */ +fun defaultActionLabel(state: ModelAvailabilityState): String? = when (state) { + is ModelAvailabilityState.Ready -> "Update" + is ModelAvailabilityState.Preparing -> if (state.isAutoQueued) null else "Cancel" + is ModelAvailabilityState.ActionRequired -> when (state.reason) { + is ActionReason.SignInRequired -> "Sign in to HuggingFace" + is ActionReason.LicenseRequired -> "View license" + is ActionReason.ApprovalPending -> null + is ActionReason.AccessApprovalRequired -> "Request access" + is ActionReason.InsufficientStorage -> "Manage storage" + is ActionReason.DownloadFailed -> "Retry download" + } + is ModelAvailabilityState.Unavailable -> null + ModelAvailabilityState.NotDisplayed -> "Download" +} diff --git a/core/model-availability/src/main/java/com/kernel/ai/core/model/availability/StateBadge.kt b/core/model-availability/src/main/java/com/kernel/ai/core/model/availability/StateBadge.kt new file mode 100644 index 000000000..bfe81667f --- /dev/null +++ b/core/model-availability/src/main/java/com/kernel/ai/core/model/availability/StateBadge.kt @@ -0,0 +1,110 @@ +package com.kernel.ai.core.model.availability + +import androidx.compose.foundation.layout.Row +import androidx.compose.foundation.layout.Spacer +import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.layout.size +import androidx.compose.foundation.layout.width +import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.Block +import androidx.compose.material.icons.filled.CheckCircle +import androidx.compose.material.icons.filled.HourglassEmpty +import androidx.compose.material.icons.filled.WarningAmber +import androidx.compose.material3.Icon +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.Surface +import androidx.compose.material3.Text +import androidx.compose.runtime.Composable +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.graphics.Color +import androidx.compose.ui.graphics.vector.ImageVector +import androidx.compose.ui.unit.dp + +/** + * Compact state badge chip for model availability. + * + * Uses M3 `AssistChip` shape at 28dp height with an icon + label. + * Maps [ModelAvailabilityState] to the appropriate icon and color scheme. + * + * @param state The availability state to render. + * @param modifier Modifier for the chip. + */ +@Composable +fun StateBadge( + state: ModelAvailabilityState, + modifier: Modifier = Modifier, +) { + val (label, icon, containerColor, contentColor) = when (state) { + is ModelAvailabilityState.Ready -> BadgeValues( + label = "Ready", + icon = Icons.Default.CheckCircle, + containerColor = MaterialTheme.colorScheme.primaryContainer, + contentColor = MaterialTheme.colorScheme.onPrimaryContainer, + ) + is ModelAvailabilityState.Preparing -> BadgeValues( + label = if (state.isAutoQueued) "Waiting" else "Downloading", + icon = Icons.Default.HourglassEmpty, + containerColor = MaterialTheme.colorScheme.secondaryContainer, + contentColor = MaterialTheme.colorScheme.onSecondaryContainer, + ) + is ModelAvailabilityState.ActionRequired -> BadgeValues( + label = when (state.reason) { + is ActionReason.SignInRequired -> "Sign in" + is ActionReason.LicenseRequired -> "License" + is ActionReason.ApprovalPending -> "Pending" + is ActionReason.AccessApprovalRequired -> "Access" + is ActionReason.InsufficientStorage -> "Storage" + is ActionReason.DownloadFailed -> "Failed" + }, + icon = Icons.Default.WarningAmber, + containerColor = MaterialTheme.colorScheme.errorContainer, + contentColor = MaterialTheme.colorScheme.onErrorContainer, + ) + is ModelAvailabilityState.Unavailable -> BadgeValues( + label = when (state.reason) { + is UnavailableReason.AccessDenied -> "Denied" + is UnavailableReason.ProviderUnavailable -> "Unavailable" + is UnavailableReason.ModelRemoved -> "Removed" + is UnavailableReason.UnsupportedDevice -> "Unsupported" + is UnavailableReason.NotBundled -> "Not available" + }, + icon = Icons.Default.Block, + containerColor = MaterialTheme.colorScheme.surfaceVariant, + contentColor = MaterialTheme.colorScheme.onSurfaceVariant, + ) + ModelAvailabilityState.NotDisplayed -> return // Don't render a badge + } + + Surface( + shape = RoundedCornerShape(8.dp), + color = containerColor, + modifier = modifier, + ) { + Row( + modifier = Modifier.padding(horizontal = 8.dp, vertical = 4.dp), + verticalAlignment = Alignment.CenterVertically, + ) { + Icon( + imageVector = icon, + contentDescription = null, + modifier = Modifier.size(14.dp), + tint = contentColor, + ) + Spacer(Modifier.width(4.dp)) + Text( + text = label, + style = MaterialTheme.typography.labelSmall, + color = contentColor, + ) + } + } +} + +private data class BadgeValues( + val label: String, + val icon: ImageVector, + val containerColor: Color, + val contentColor: Color, +) diff --git a/core/model-availability/src/test/java/com/kernel/ai/core/model/availability/DownloadStateMapperTest.kt b/core/model-availability/src/test/java/com/kernel/ai/core/model/availability/DownloadStateMapperTest.kt new file mode 100644 index 000000000..6ae841c1d --- /dev/null +++ b/core/model-availability/src/test/java/com/kernel/ai/core/model/availability/DownloadStateMapperTest.kt @@ -0,0 +1,145 @@ +package com.kernel.ai.core.model.availability + +import com.kernel.ai.core.inference.download.DownloadSource +import com.kernel.ai.core.inference.download.DownloadState +import com.kernel.ai.core.inference.download.KernelModel +import com.kernel.ai.core.model.availability.ModelAvailabilityState.ActionRequired +import com.kernel.ai.core.model.availability.ModelAvailabilityState.Preparing +import com.kernel.ai.core.model.availability.ModelAvailabilityState.Unavailable +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertFalse +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Nested +import org.junit.jupiter.api.Test + +class DownloadStateMapperTest { + + private val ungatedModel = KernelModel.GEMMA_4_E2B + private val gatedModel = KernelModel.EMBEDDING_GEMMA_300M + private val bundledModel = KernelModel.MINI_LM + + @Test + fun `downloaded maps to Ready`() { + val result = DownloadState.Downloaded("/path/to/model") + .toAvailability(ungatedModel, hfAuth = false) + assertEquals(ModelAvailabilityState.Ready, result) + } + + @Test + fun `bundled model always Ready even when NotDownloaded`() { + val result = DownloadState.NotDownloaded + .toAvailability(bundledModel, hfAuth = false) + assertEquals(ModelAvailabilityState.Ready, result) + } + + @Test + fun `downloading maps to Preparing with progress`() { + val result = DownloadState.Downloading(progress = 0.42f) + .toAvailability(ungatedModel, hfAuth = false) + assertEquals(Preparing(progress = 0.42f, isAutoQueued = false), result) + } + + @Test + fun `downloading with AUTO_QUEUED source maps to Preparing isAutoQueued true`() { + val result = DownloadState.Downloading(progress = 0.5f) + .toAvailability(ungatedModel, hfAuth = false, source = DownloadSource.AUTO_QUEUED) + assertEquals(Preparing(progress = 0.5f, isAutoQueued = true), result) + } + + @Nested + inner class GatedModels { + + @Test + fun `not downloaded gated model without HF auth maps to SignInRequired`() { + val result = DownloadState.NotDownloaded + .toAvailability(gatedModel, hfAuth = false) + assertEquals(ActionRequired(ActionReason.SignInRequired), result) + } + + @Test + fun `not downloaded gated model with HF auth and APPROVAL_PENDING maps to ApprovalPending`() { + val result = DownloadState.NotDownloaded + .toAvailability(gatedModel, hfAuth = true, gated = GatedModelStatus.APPROVAL_PENDING) + assertEquals(ActionRequired(ActionReason.ApprovalPending), result) + } + + @Test + fun `not downloaded gated model with HF auth and ACCESS_DENIED maps to AccessDenied`() { + val result = DownloadState.NotDownloaded + .toAvailability(gatedModel, hfAuth = true, gated = GatedModelStatus.ACCESS_DENIED) + assertEquals(Unavailable(UnavailableReason.AccessDenied), result) + } + + @Test + fun `not downloaded gated model with HF auth and NONE status maps to NotDisplayed`() { + val result = DownloadState.NotDownloaded + .toAvailability(gatedModel, hfAuth = true, gated = GatedModelStatus.NONE) + assertEquals(ModelAvailabilityState.NotDisplayed, result) + } + + @Test + fun `not downloaded gated model with HF auth and APPROVED status maps to NotDisplayed`() { + val result = DownloadState.NotDownloaded + .toAvailability(gatedModel, hfAuth = true, gated = GatedModelStatus.APPROVED) + assertEquals(ModelAvailabilityState.NotDisplayed, result) + } + } + + @Nested + inner class UngatedNotDownloaded { + + @Test + fun `not downloaded ungated model with AUTO_QUEUED source maps to Preparing isAutoQueued`() { + val result = DownloadState.NotDownloaded + .toAvailability(ungatedModel, hfAuth = false, source = DownloadSource.AUTO_QUEUED) + assertEquals(Preparing(progress = 0f, isAutoQueued = true), result) + } + + @Test + fun `not downloaded ungated model with USER_INITIATED source maps to NotDisplayed`() { + val result = DownloadState.NotDownloaded + .toAvailability(ungatedModel, hfAuth = false, source = DownloadSource.USER_INITIATED) + assertEquals(ModelAvailabilityState.NotDisplayed, result) + } + } + + @Nested + inner class ErrorStates { + + @Test + fun `error with licenceRequired maps to LicenseRequired`() { + val result = DownloadState.Error( + message = "Licence not accepted", + licenceRequired = true, + ).toAvailability(ungatedModel, hfAuth = false) + assertEquals(ActionRequired(ActionReason.LicenseRequired), result) + } + + @Test + fun `error without licence maps to DownloadFailed`() { + val result = DownloadState.Error( + message = "Network timeout", + licenceRequired = false, + ).toAvailability(ungatedModel, hfAuth = false) + assertEquals(ActionRequired(ActionReason.DownloadFailed("Network timeout")), result) + } + } + + @Nested + inner class EdgeCases { + + @Test + fun `downloaded state regardless of gated or auth returns Ready`() { + val result = DownloadState.Downloaded("/path") + .toAvailability(gatedModel, hfAuth = false) + assertEquals(ModelAvailabilityState.Ready, result) + } + + @Test + fun `downloading regardless of gated or auth returns Preparing`() { + val result = DownloadState.Downloading(progress = 0.1f) + .toAvailability(gatedModel, hfAuth = false) + assertEquals(Preparing(progress = 0.1f, isAutoQueued = false), result) + } + } +} diff --git a/core/ui/src/main/java/com/kernel/ai/core/ui/CollapsibleSectionHeader.kt b/core/ui/src/main/java/com/kernel/ai/core/ui/CollapsibleSectionHeader.kt new file mode 100644 index 000000000..45a888154 --- /dev/null +++ b/core/ui/src/main/java/com/kernel/ai/core/ui/CollapsibleSectionHeader.kt @@ -0,0 +1,75 @@ +package com.kernel.ai.core.ui + +import androidx.compose.animation.core.animateFloatAsState +import androidx.compose.foundation.clickable +import androidx.compose.foundation.layout.Arrangement +import androidx.compose.foundation.layout.Row +import androidx.compose.foundation.layout.fillMaxWidth +import androidx.compose.foundation.layout.padding +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.KeyboardArrowDown +import androidx.compose.material3.Badge +import androidx.compose.material3.Icon +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.Text +import androidx.compose.runtime.Composable +import androidx.compose.runtime.getValue +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.draw.rotate +import androidx.compose.ui.unit.dp + +/** + * Expandable section header with a chevron that rotates on toggle. + * + * @param title Section label text. + * @param count Optional count badge shown next to the title. + * @param isExpanded Whether the section content is visible. + * @param onToggle Called when the user taps the header. + */ +@Composable +fun CollapsibleSectionHeader( + title: String, + count: Int? = null, + isExpanded: Boolean, + onToggle: () -> Unit, +) { + val rotation by animateFloatAsState( + targetValue = if (isExpanded) 180f else 0f, + label = "chevron", + ) + Row( + modifier = Modifier + .fillMaxWidth() + .clickable { onToggle() } + .padding(horizontal = 16.dp, vertical = 8.dp), + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.SpaceBetween, + ) { + Row( + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.spacedBy(8.dp), + ) { + Text( + text = title, + style = MaterialTheme.typography.labelMedium, + color = MaterialTheme.colorScheme.primary, + ) + if (count != null && count > 0) { + Badge(containerColor = MaterialTheme.colorScheme.secondaryContainer) { + Text( + text = count.toString(), + color = MaterialTheme.colorScheme.onSecondaryContainer, + style = MaterialTheme.typography.labelSmall, + ) + } + } + } + Icon( + imageVector = Icons.Default.KeyboardArrowDown, + contentDescription = if (isExpanded) "Collapse" else "Expand", + modifier = Modifier.rotate(rotation), + tint = MaterialTheme.colorScheme.primary, + ) + } +} diff --git a/docs/agents/model-availability-ux-patterns.md b/docs/agents/model-availability-ux-patterns.md new file mode 100644 index 000000000..d1cc08877 --- /dev/null +++ b/docs/agents/model-availability-ux-patterns.md @@ -0,0 +1,29 @@ +# Model Availability State — Canonical Reference + +## 4-state machine + +|State|Meaning|User action| +|---|---|---| +|Ready|Model is on disk and ready|None| +|Preparing|Download in progress or auto-queued|Cancel (user-initiated only)| +|Action Required|User must sign in, accept licence, etc.|Primary action button| +|Unavailable|Cannot be used (denied, unsupported, etc.)|Informational only| + +## Truth table + +See `DownloadState.kt` → `DownloadStateMapper.kt` for the full mapping. + +## File locations + +|File|Purpose| +|---|---| +|`core/model-availability/.../ModelAvailabilityState.kt`|Sealed class + subtypes| +|`core/model-availability/.../DownloadStateMapper.kt`|`DownloadState.toAvailability()`| +|`core/model-availability/.../StateBadge.kt`|Composable badge chip| +|`core/model-availability/.../ModelCard.kt`|ModelCard + ModelCardCompact| +|`core/model-availability/.../GatedModelStatus.kt`|Enum for gated model access| +|`core/model-availability/.../GatedModelStatusRepository.kt`|DataStore repository| + +## Commit ordering (feature/model-availability-ux) + +See `git log` — 16 commits, each atomic and CI-green. diff --git a/docs/manual-test-scenarios-1067.md b/docs/manual-test-scenarios-1067.md new file mode 100644 index 000000000..4cf346147 --- /dev/null +++ b/docs/manual-test-scenarios-1067.md @@ -0,0 +1,206 @@ +# Manual Device Test Scenarios — PR #1067 + +## Prerequisites +- Device connected via ADB (S23 Ultra or equivalent with GPU inference) +- Fresh install recommended for baseline, then incremental tests +- `adb logcat -s KernelAI` for download state logging + +--- + +### 1. Fresh install — auto-queue flow +**Steps:** +1. Clear app data or fresh install +2. Launch app → observe onboarding screen +3. Wait for download manager init (~2s) + +**Expected:** +- Chat onboarding shows `ModelCardCompact` for each auto-queued model +- Each card shows the correct state badge: + - `Preparing (Waiting)` for auto-queued models not yet started + - `Preparing (Downloading)` once download begins +- Progress updates in real time (no stale 0% stuck) +- Model count summary "X of Y models ready" on Settings screen starts at "0 of Y" + +--- + +### 2. Model Management screen — state badges +**Steps:** +1. Navigate to Settings → Model availability (or direct route) +2. Observe each model card + +**Expected:** +- Downloaded models show `Ready` badge (green `CheckCircle`) +- Downloading models show `Preparing` (amber `HourglassEmpty`) +- `Preparing` shows `Downloading` label for user-initiated, `Waiting` for auto-queued +- Error state shows `Failed` badge (red `WarningAmber`) +- Gated models not yet authenticated show `Sign in` badge +- HuggingFace row is **removed** from Model Management (moved to account section) +- Deprecated model (SM8550) is not visible in the list + +--- + +### 3. Action buttons on ModelCard +**Steps:** +1. For a `NotDownloaded` model — tap the model card's button +2. For a `Downloading` model — tap Cancel +3. For a `DownloadFailed` model — tap Retry + +**Expected:** +- `NotDownloaded` → tap starts download, badge transitions to `Preparing` +- `Downloading` → Cancel stops the download (unless `isRequired`) +- `DownloadFailed` → Retry restarts download +- Required models (`isRequired = true`) — Cancel button is **not shown** +- Action button is full-width `Button` (filled) for actionable states, `OutlinedButton` for Unavailable + +--- + +### 4. Chat onboarding — ModelCardCompact integration +**Steps:** +1. Fresh install (or delete models and restart) +2. Observe the onboarding progress section + +**Expected:** +- Each model shows `ModelCardCompact` with: + - Model name (left-aligned) + - Size label / description + - State badge (right-aligned) +- Lock icon shown for gated models +- Tapping "Manage models" navigates to Model Management screen + +--- + +### 5. Voice screen — ModelCardCompact for voices and STT +**Steps:** +1. Navigate to Settings → Voice +2. Expand Sherpa-ONNX section +3. Observe STT model cards +4. Observe voice model cards (Sherpa Piper, Kokoro) + +**Expected:** +- Each STT engine shows `ModelCardCompact` with state badge +- Sherpa Piper voices show `ModelCardCompact` + radio button for selection +- Kokoro voices show `ModelCardCompact` + radio button +- Downloaded voices show `Ready` badge, radio button enabled +- Not-downloaded voices show `Not available` badge, radio button disabled +- Downloading voices show `Preparing` badge with progress +- `ModelCardCompact` has NO action buttons (consistent with design) +- "Manage voice models" `TextButton` at bottom of each section navigates to Model Management + +--- + +### 6. Settings screen — model availability summary +**Steps:** +1. Navigate to Settings +2. Observe the new "Model availability" row + +**Expected:** +- Row shows `AvailabilitySummary` string: "X of Y models ready" +- Count matches observed states: + - `Ready` + `Unavailable` = ready count (unavailable models aren't actionable) +- Tapping row navigates to Model Management screen +- HuggingFace account row is **removed** from Settings (was previously grouped) + +--- + +### 7. Model Settings screen — StateBadge on model cards +**Steps:** +1. Navigate to Settings → Model settings (or conversation model settings) +2. Observe E2B and E4B card headers + +**Expected:** +- Each card header shows `StateBadge` next to model name +- Badge reflects current download/availability state +- Badge updates live as download state changes + +--- + +### 8. Cancel download guard — required models +**Steps:** +1. While a required model (e.g. E2B or E4B) is downloading +2. Try to cancel it from the UI + +**Expected:** +- Cancel button is **not shown** for required models +- If cancellation is attempted programmatically, `cancelDownload()` logs a warning and returns without cancelling +- `isRequired` guard covers both UI and programmatic paths + +--- + +### 9. HuggingFace auth — gated model states +**Steps:** +1. Without HF auth, observe a gated model (e.g. EmbeddingGemma-300M) +2. Sign in to HuggingFace +3. Check gated model status after sign-in + +**Expected:** +- Without auth: gated model shows `ActionRequired (SignInRequired)` badge +- "Sign in to HuggingFace" button appears on ModelCard +- After auth + approval: badge transitions appropriately +- `GatedModelStatusRepository` persists status across app restarts + +--- + +### 10. Deprecated model — SM8550 hidden +**Steps:** +1. Navigate to Model Management +2. Search for "SM8550" in the list + +**Expected:** +- `EMBEDDING_GEMMA_300M_SM8550` is not shown in Model Management +- Model is marked `isDeprecated = true` in code +- Existing download is not deleted (must be manually removed via storage settings) +- Deprecated model is excluded from `preferredForTier` matching + +--- + +### 11. State survival across config changes +**Steps:** +1. Start a download +2. Rotate the device (or trigger config change) +3. Observe all screens + +**Expected:** +- Download progress survives rotation (ViewModel + WorkManager) +- State badges remain correct after rotation +- No Compose recomposition crashes or NPEs +- DataStore-backed `GatedModelStatusRepository` state persists + +--- + +### 12. CollapsibleSectionHeader — memory screen extraction +**Steps:** +1. Navigate to Settings → Memory +2. Observe section headers + +**Expected:** +- `CollapsibleSectionHeader` renders correctly (same visual as before) +- Chevron rotates on expand/collapse +- Count badge shows correct count where applicable +- No regression from extraction to shared `:core:ui` module + +--- + +### 13. Navigation — Model Management route +**Steps:** +1. From Chat onboarding → tap "Manage models" → verify navigation +2. From Settings → tap "Model availability" row → verify navigation +3. From Voice screen → tap "Manage voice models" → verify navigation +4. Press back from Model Management → verify correct return screen + +**Expected:** +- All three entry points navigate to Model Management +- Back navigation returns to the correct previous screen +- No double-navigation or crash + +--- + +### 14. Regression check — existing download states unchanged +**Steps:** +1. Install app with models already downloaded +2. Launch app + +**Expected:** +- Downloaded models show `Ready` badge immediately +- No unnecessary re-downloads triggered +- `isDownloaded()` check prevents re-queuing +- Bundled models (`MINI_LM`) show `Ready` badge even without download diff --git a/feature/chat/build.gradle.kts b/feature/chat/build.gradle.kts index 44d8dc500..096f12a49 100644 --- a/feature/chat/build.gradle.kts +++ b/feature/chat/build.gradle.kts @@ -42,6 +42,7 @@ dependencies { implementation(project(":core:memory")) implementation(project(":core:voice")) implementation(project(":core:skills")) + implementation(project(":core:model-availability")) // LiteRT-LM — needed to resolve ToolProvider / ToolSet types at compile time implementation(libs.litertlm.android) diff --git a/feature/chat/src/main/java/com/kernel/ai/feature/chat/ChatScreen.kt b/feature/chat/src/main/java/com/kernel/ai/feature/chat/ChatScreen.kt index 8cc25499f..6d6f5f24e 100644 --- a/feature/chat/src/main/java/com/kernel/ai/feature/chat/ChatScreen.kt +++ b/feature/chat/src/main/java/com/kernel/ai/feature/chat/ChatScreen.kt @@ -148,7 +148,10 @@ import androidx.hilt.navigation.compose.hiltViewModel import androidx.lifecycle.compose.collectAsStateWithLifecycle import com.kernel.ai.feature.chat.R import com.kernel.ai.core.inference.download.DownloadState +import com.kernel.ai.core.inference.download.DownloadSource import com.kernel.ai.core.inference.download.KernelModel +import com.kernel.ai.core.model.availability.ModelCardCompact +import com.kernel.ai.core.model.availability.toAvailability import com.kernel.ai.core.skills.mealplan.MealPlannerActivity import com.kernel.ai.core.skills.mealplan.MealPlannerActivityState import com.kernel.ai.core.skills.mealplan.MealPlannerSuggestion @@ -202,6 +205,7 @@ fun ChatScreen( onNewConversation: () -> Unit = {}, onNavigateToList: () -> Unit = {}, onNavigateToSettings: () -> Unit = {}, + onNavigateToModelManagement: () -> Unit = {}, viewModel: ChatViewModel = hiltViewModel(), ) { val uiState by viewModel.uiState.collectAsStateWithLifecycle() @@ -251,7 +255,9 @@ fun ChatScreen( isDownloading = state.isDownloading, modelProgress = state.modelProgress, onRetry = viewModel::retryDownload, - onNavigateToSettings = onNavigateToSettings, + onNavigateToModelManagement = onNavigateToModelManagement, + hfAuthenticated = state.hfAuthenticated, + downloadSources = state.downloadSources, ) is ChatUiState.Ready -> { val context = LocalContext.current @@ -1675,9 +1681,11 @@ private fun LoadingContent() { @Composable private fun OnboardingContent( isDownloading: Boolean, - modelProgress: List, + modelProgress: List, onRetry: (KernelModel) -> Unit, - onNavigateToSettings: () -> Unit, + onNavigateToModelManagement: () -> Unit, + hfAuthenticated: Boolean = false, + downloadSources: Map = emptyMap(), ) { Box(modifier = Modifier.fillMaxSize(), contentAlignment = Alignment.Center) { Column( @@ -1710,9 +1718,25 @@ private fun OnboardingContent( verticalArrangement = Arrangement.spacedBy(20.dp), ) { modelProgress.forEach { item -> - ModelProgressRow(item, onRetry = onRetry, onNavigateToSettings = onNavigateToSettings) + val source = downloadSources[item.model] + ModelCardCompact( + title = item.displayName, + description = item.sizeLabel, + state = item.state.toAvailability( + model = item.model, + hfAuth = hfAuthenticated, + source = source ?: DownloadSource.USER_INITIATED, + ), + showLock = item.model.isGated, + ) } } + TextButton( + onClick = onNavigateToModelManagement, + modifier = Modifier.padding(top = 8.dp), + ) { + Text("Manage models") + } } else if (isDownloading) { CircularProgressIndicator(modifier = Modifier.padding(top = 24.dp)) } @@ -1720,128 +1744,7 @@ private fun OnboardingContent( } } -@Composable -private fun ModelProgressRow( - item: ModelDownloadProgress, - onRetry: (KernelModel) -> Unit, - onNavigateToSettings: () -> Unit, -) { - val state = item.state - Column(modifier = Modifier.fillMaxWidth()) { - Row( - modifier = Modifier.fillMaxWidth(), - horizontalArrangement = Arrangement.SpaceBetween, - verticalAlignment = Alignment.CenterVertically, - ) { - Text( - text = item.displayName, - style = MaterialTheme.typography.bodyMedium, - ) - Text( - text = when (state) { - is DownloadState.Downloaded -> "✓ Ready" - is DownloadState.Downloading -> { - val pct = (state.progress * 100).toInt() - if (state.bytesPerSecond > 0) { - val mbps = state.bytesPerSecond / 1_048_576.0 - "$pct% · ${"%.1f".format(mbps)} MB/s" - } else "$pct%" - } - is DownloadState.Error -> "Error" - is DownloadState.NotDownloaded -> item.sizeLabel - }, - style = MaterialTheme.typography.bodySmall, - color = when (state) { - is DownloadState.Downloaded -> MaterialTheme.colorScheme.primary - is DownloadState.Error -> MaterialTheme.colorScheme.error - else -> MaterialTheme.colorScheme.onSurfaceVariant - }, - ) - } - - Spacer(modifier = Modifier.height(6.dp)) - - when (state) { - is DownloadState.Downloading -> { - LinearProgressIndicator( - progress = { state.progress }, - modifier = Modifier.fillMaxWidth().height(6.dp), - ) - if (state.remainingMs > 0) { - val etaText = formatEta(state.remainingMs) - Text( - text = etaText, - style = MaterialTheme.typography.labelSmall, - color = MaterialTheme.colorScheme.onSurfaceVariant, - modifier = Modifier.padding(top = 2.dp), - ) - } - } - is DownloadState.Downloaded -> { - LinearProgressIndicator( - progress = { 1f }, - modifier = Modifier.fillMaxWidth().height(6.dp), - ) - } - is DownloadState.Error -> { - Row( - verticalAlignment = Alignment.CenterVertically, - horizontalArrangement = Arrangement.spacedBy(8.dp), - ) { - Text( - text = "Download failed", - style = MaterialTheme.typography.labelSmall, - color = MaterialTheme.colorScheme.error, - modifier = Modifier.weight(1f), - ) - Button( - onClick = { onRetry(item.model) }, - contentPadding = PaddingValues(horizontal = 12.dp, vertical = 4.dp), - ) { - Text("Retry", style = MaterialTheme.typography.labelMedium) - } - } - } - is DownloadState.NotDownloaded -> { - if (item.model.isGated) { - Row( - verticalAlignment = Alignment.CenterVertically, - horizontalArrangement = Arrangement.spacedBy(8.dp), - ) { - Text( - text = "Sign in to HuggingFace to download", - style = MaterialTheme.typography.labelSmall, - color = MaterialTheme.colorScheme.onSurfaceVariant, - modifier = Modifier.weight(1f), - ) - Button( - onClick = onNavigateToSettings, - contentPadding = PaddingValues(horizontal = 12.dp, vertical = 4.dp), - ) { - Text("Sign in", style = MaterialTheme.typography.labelMedium) - } - } - } else { - Text( - text = "Queued", - style = MaterialTheme.typography.labelSmall, - color = MaterialTheme.colorScheme.onSurfaceVariant, - modifier = Modifier.padding(top = 2.dp), - ) - } - } - } - } -} -private fun formatEta(remainingMs: Long): String { - val totalSecs = remainingMs / 1000 - return when { - totalSecs < 60 -> "~${totalSecs}s remaining" - totalSecs < 3600 -> "~${totalSecs / 60}m ${totalSecs % 60}s remaining" - else -> "~${totalSecs / 3600}h ${(totalSecs % 3600) / 60}m remaining" - } -} @Composable private fun ToolCallChip(toolCall: ToolCallInfo, modifier: Modifier = Modifier) { diff --git a/feature/chat/src/main/java/com/kernel/ai/feature/chat/ChatViewModel.kt b/feature/chat/src/main/java/com/kernel/ai/feature/chat/ChatViewModel.kt index fb71845c3..25f258e7b 100644 --- a/feature/chat/src/main/java/com/kernel/ai/feature/chat/ChatViewModel.kt +++ b/feature/chat/src/main/java/com/kernel/ai/feature/chat/ChatViewModel.kt @@ -24,8 +24,10 @@ import com.kernel.ai.core.inference.ModelConfig import com.kernel.ai.core.inference.PersonaMode import com.kernel.ai.core.inference.capabilities import com.kernel.ai.core.inference.download.DownloadState +import com.kernel.ai.core.inference.download.DownloadSource import com.kernel.ai.core.inference.download.KernelModel import com.kernel.ai.core.inference.download.ModelDownloadManager +import com.kernel.ai.core.inference.auth.HuggingFaceAuthRepository import com.kernel.ai.core.inference.hardware.HardwareTier import com.kernel.ai.core.memory.rag.RagRepository import com.kernel.ai.core.memory.repository.ConversationRepository @@ -135,6 +137,7 @@ class ChatViewModel @Inject constructor( private val jandalPersona: JandalPersona, private val nzTruthSeedingService: NzTruthSeedingService, private val verboseLoggingPreferenceUseCase: com.kernel.ai.core.memory.usecase.VerboseLoggingPreferenceUseCase, + private val authRepository: HuggingFaceAuthRepository, private val startListeningCuePlayer: StartListeningCuePlayer, private val chatPreferences: ChatPreferences, ) : ViewModel() { @@ -360,15 +363,27 @@ class ChatViewModel @Inject constructor( ) { messages, inputText, error, title, isSpeakingResponse -> InputState(messages, inputText, error, title, isSpeakingResponse) } - - /** Base uiState without visual prefs (5-input combine). */ + /** Base uiState without visual prefs (7-input combine). */ private val baseUiState: StateFlow = combine( engineState, downloadManager.downloadStates, + downloadManager.downloadSources, inputState, _showThinkingProcess, isArchived, - ) { engine, downloadStates, input, showThinking, archived -> + authRepository.isAuthenticated, + ) { array -> + @Suppress("UNCHECKED_CAST") + val engine = array[0] as EngineState + @Suppress("UNCHECKED_CAST") + val downloadStates = array[1] as Map + @Suppress("UNCHECKED_CAST") + val downloadSources = array[2] as Map + @Suppress("UNCHECKED_CAST") + val input = array[3] as InputState + val showThinking = array[4] as Boolean + val archived = array[5] as Boolean + val hfAuth = array[6] as Boolean val allDownloaded = downloadManager.areRequiredModelsDownloaded() val tier = downloadManager.deviceTier val displayModels: List = if (tier == HardwareTier.FLAGSHIP) { @@ -389,7 +404,12 @@ class ChatViewModel @Inject constructor( state = downloadStates[model] ?: DownloadState.NotDownloaded, ) } - ChatUiState.ModelsNotReady(isDownloading = anyDownloading, modelProgress = progress) + ChatUiState.ModelsNotReady( + isDownloading = anyDownloading, + modelProgress = progress, + hfAuthenticated = hfAuth, + downloadSources = downloadSources, + ) } // Archived conversations are read-only — no engine needed. Skip the isReady gate. !archived && (!engine.isReady || !engine.conversationInitialized) -> ChatUiState.Loading diff --git a/feature/chat/src/main/java/com/kernel/ai/feature/chat/model/ChatUiState.kt b/feature/chat/src/main/java/com/kernel/ai/feature/chat/model/ChatUiState.kt index d385a4d51..03916ff51 100644 --- a/feature/chat/src/main/java/com/kernel/ai/feature/chat/model/ChatUiState.kt +++ b/feature/chat/src/main/java/com/kernel/ai/feature/chat/model/ChatUiState.kt @@ -1,6 +1,7 @@ package com.kernel.ai.feature.chat.model import com.kernel.ai.core.inference.ModelCapabilities +import com.kernel.ai.core.inference.download.DownloadSource import com.kernel.ai.core.inference.download.DownloadState import com.kernel.ai.core.inference.download.KernelModel @@ -51,6 +52,8 @@ sealed interface ChatUiState { val isDownloading: Boolean, /** Per-model download progress, ordered by priority (required first). */ val modelProgress: List = emptyList(), + val hfAuthenticated: Boolean = false, + val downloadSources: Map = emptyMap(), ) : ChatUiState data class ModelDownloadProgress( diff --git a/feature/chat/src/test/java/com/kernel/ai/feature/chat/ChatViewModelInitTest.kt b/feature/chat/src/test/java/com/kernel/ai/feature/chat/ChatViewModelInitTest.kt index 21dc12a60..8cbf25ecd 100644 --- a/feature/chat/src/test/java/com/kernel/ai/feature/chat/ChatViewModelInitTest.kt +++ b/feature/chat/src/test/java/com/kernel/ai/feature/chat/ChatViewModelInitTest.kt @@ -43,6 +43,8 @@ import com.kernel.ai.core.voice.VoiceInputController import com.kernel.ai.core.voice.VoiceOutputController import com.kernel.ai.core.voice.VoiceOutputPreferences import com.kernel.ai.core.voice.StartListeningCuePlayer +import com.kernel.ai.core.inference.auth.HuggingFaceAuthRepository +import com.kernel.ai.core.memory.prefs.ChatPreferences import io.mockk.coEvery import io.mockk.coVerify import io.mockk.every @@ -96,7 +98,8 @@ class ChatViewModelInitTest { private val nzTruthSeedingService: NzTruthSeedingService = mockk(relaxed = true) private val verboseLoggingPreferenceUseCase: VerboseLoggingPreferenceUseCase = mockk(relaxed = true) private val startListeningCuePlayer: StartListeningCuePlayer = mockk(relaxed = true) - private val chatPreferences: com.kernel.ai.core.memory.prefs.ChatPreferences = mockk(relaxed = true) + private val authRepository: HuggingFaceAuthRepository = mockk(relaxed = true) + private val chatPreferences: ChatPreferences = mockk(relaxed = true) @BeforeEach fun setUp() { @@ -156,6 +159,7 @@ class ChatViewModelInitTest { @Test fun `fresh chat initialization resets inherited inference session`() = runTest(dispatcher) { ChatViewModel(savedStateHandle = SavedStateHandle(), chatPreferences = chatPreferences, + authRepository = authRepository, inferenceEngine = inferenceEngine, downloadManager = downloadManager, conversationRepository = conversationRepository, @@ -191,6 +195,7 @@ class ChatViewModelInitTest { @Test fun `restored chat initialization does not reset current inference session`() = runTest(dispatcher) { ChatViewModel(savedStateHandle = SavedStateHandle(mapOf("conversationId" to "conv-existing")), chatPreferences = chatPreferences, + authRepository = authRepository, inferenceEngine = inferenceEngine, downloadManager = downloadManager, conversationRepository = conversationRepository, @@ -226,6 +231,7 @@ class ChatViewModelInitTest { @Test fun `closing chat never shuts down process scoped inference engine`() = runTest(dispatcher) { val viewModel = ChatViewModel(savedStateHandle = SavedStateHandle(mapOf("conversationId" to "conv-existing")), chatPreferences = chatPreferences, + authRepository = authRepository, inferenceEngine = inferenceEngine, downloadManager = downloadManager, conversationRepository = conversationRepository, @@ -268,6 +274,7 @@ class ChatViewModelInitTest { coEvery { mealPlannerCoordinator.activeSessionReply("conv-existing") } returns MealPlannerReply(prompt) val viewModel = ChatViewModel(savedStateHandle = SavedStateHandle(mapOf("conversationId" to "conv-existing")), chatPreferences = chatPreferences, + authRepository = authRepository, inferenceEngine = inferenceEngine, downloadManager = downloadManager, conversationRepository = conversationRepository, @@ -315,6 +322,7 @@ class ChatViewModelInitTest { coEvery { mealPlannerCoordinator.activeSessionActivity("conv-existing") } returns activity val viewModel = ChatViewModel(savedStateHandle = SavedStateHandle(mapOf("conversationId" to "conv-existing")), chatPreferences = chatPreferences, + authRepository = authRepository, inferenceEngine = inferenceEngine, downloadManager = downloadManager, conversationRepository = conversationRepository, @@ -393,6 +401,7 @@ class ChatViewModelInitTest { coEvery { mealPlanSessionRepository.getActiveSession("conv-existing") } returns snapshot val viewModel = ChatViewModel(savedStateHandle = SavedStateHandle(mapOf("conversationId" to "conv-existing")), chatPreferences = chatPreferences, + authRepository = authRepository, inferenceEngine = inferenceEngine, downloadManager = downloadManager, conversationRepository = conversationRepository, @@ -445,6 +454,7 @@ class ChatViewModelInitTest { ) val viewModel = ChatViewModel(savedStateHandle = SavedStateHandle(mapOf("conversationId" to "conv-existing")), chatPreferences = chatPreferences, + authRepository = authRepository, inferenceEngine = inferenceEngine, downloadManager = downloadManager, conversationRepository = conversationRepository, @@ -514,6 +524,7 @@ class ChatViewModelInitTest { ) val viewModel = ChatViewModel(savedStateHandle = SavedStateHandle(mapOf("conversationId" to "conv-existing")), chatPreferences = chatPreferences, + authRepository = authRepository, inferenceEngine = inferenceEngine, downloadManager = downloadManager, conversationRepository = conversationRepository, @@ -571,6 +582,7 @@ class ChatViewModelInitTest { ) val viewModel = ChatViewModel(savedStateHandle = savedStateHandle, chatPreferences = chatPreferences, + authRepository = authRepository, inferenceEngine = inferenceEngine, downloadManager = downloadManager, conversationRepository = conversationRepository, @@ -630,6 +642,7 @@ class ChatViewModelInitTest { QuickIntentRouter.RouteResult.FallThrough(input = "and bred to my last") val viewModel = ChatViewModel(savedStateHandle = SavedStateHandle(mapOf("minimalContext" to true)), chatPreferences = chatPreferences, + authRepository = authRepository, inferenceEngine = inferenceEngine, downloadManager = downloadManager, conversationRepository = conversationRepository, @@ -679,6 +692,7 @@ class ChatViewModelInitTest { ) val viewModel = ChatViewModel(savedStateHandle = SavedStateHandle(mapOf("conversationId" to "conv-existing")), chatPreferences = chatPreferences, + authRepository = authRepository, inferenceEngine = inferenceEngine, downloadManager = downloadManager, conversationRepository = conversationRepository, @@ -744,6 +758,7 @@ class ChatViewModelInitTest { ), ) val viewModel = ChatViewModel(savedStateHandle = savedStateHandle, chatPreferences = chatPreferences, + authRepository = authRepository, inferenceEngine = inferenceEngine, downloadManager = downloadManager, conversationRepository = conversationRepository, @@ -800,6 +815,7 @@ class ChatViewModelInitTest { every { quickIntentRouter.route(any()) } returns QuickIntentRouter.RouteResult.FallThrough(input = "hello") val viewModel = ChatViewModel(savedStateHandle = SavedStateHandle(), chatPreferences = chatPreferences, + authRepository = authRepository, inferenceEngine = inferenceEngine, downloadManager = downloadManager, conversationRepository = conversationRepository, @@ -852,6 +868,7 @@ class ChatViewModelInitTest { every { quickIntentRouter.route(any()) } returns QuickIntentRouter.RouteResult.FallThrough(input = "hello") val viewModel = ChatViewModel(savedStateHandle = SavedStateHandle(), chatPreferences = chatPreferences, + authRepository = authRepository, inferenceEngine = inferenceEngine, downloadManager = downloadManager, conversationRepository = conversationRepository, diff --git a/feature/chat/src/test/java/com/kernel/ai/feature/chat/ChatViewModelVoiceTest.kt b/feature/chat/src/test/java/com/kernel/ai/feature/chat/ChatViewModelVoiceTest.kt index b12b19679..d36a9eea4 100644 --- a/feature/chat/src/test/java/com/kernel/ai/feature/chat/ChatViewModelVoiceTest.kt +++ b/feature/chat/src/test/java/com/kernel/ai/feature/chat/ChatViewModelVoiceTest.kt @@ -52,6 +52,8 @@ import com.kernel.ai.core.voice.VoiceOutputPreferences import com.kernel.ai.core.voice.VoiceOutputResult import com.kernel.ai.core.voice.VoiceSpeakRequest import com.kernel.ai.core.voice.VoiceOutputStreamingSession +import com.kernel.ai.core.inference.auth.HuggingFaceAuthRepository +import com.kernel.ai.core.memory.prefs.ChatPreferences import com.kernel.ai.feature.chat.model.ChatUiState import io.mockk.coEvery import io.mockk.coVerify @@ -112,7 +114,8 @@ class ChatViewModelVoiceTest { private val nzTruthSeedingService: NzTruthSeedingService = mockk(relaxed = true) private val verboseLoggingPreferenceUseCase: VerboseLoggingPreferenceUseCase = mockk(relaxed = true) private val startListeningCuePlayer: StartListeningCuePlayer = mockk(relaxed = true) - private val chatPreferences: com.kernel.ai.core.memory.prefs.ChatPreferences = mockk(relaxed = true) + private val chatPreferences: ChatPreferences = mockk(relaxed = true) + private val authRepository: HuggingFaceAuthRepository = mockk(relaxed = true) private val voiceInputEvents = MutableSharedFlow() private val voiceOutputEvents = MutableSharedFlow() @@ -1028,7 +1031,7 @@ class ChatViewModelVoiceTest { } - private fun createViewModel(): ChatViewModel = ChatViewModel(savedStateHandle = SavedStateHandle(mapOf("conversationId" to "conv-existing")), chatPreferences = chatPreferences, + private fun createViewModel(): ChatViewModel = ChatViewModel(savedStateHandle = SavedStateHandle(mapOf("conversationId" to "conv-existing")), chatPreferences = chatPreferences, authRepository = authRepository, inferenceEngine = inferenceEngine, downloadManager = downloadManager, conversationRepository = conversationRepository, diff --git a/feature/settings/build.gradle.kts b/feature/settings/build.gradle.kts index 6ed4ca31c..bca1a9f98 100644 --- a/feature/settings/build.gradle.kts +++ b/feature/settings/build.gradle.kts @@ -40,6 +40,7 @@ dependencies { implementation(project(":core:memory")) implementation(project(":core:voice")) implementation(project(":core:skills")) + implementation(project(":core:model-availability")) // Compose implementation(platform(libs.compose.bom)) diff --git a/feature/settings/src/main/java/com/kernel/ai/feature/settings/MemoryScreen.kt b/feature/settings/src/main/java/com/kernel/ai/feature/settings/MemoryScreen.kt index ec2e8d695..d866ef2d4 100644 --- a/feature/settings/src/main/java/com/kernel/ai/feature/settings/MemoryScreen.kt +++ b/feature/settings/src/main/java/com/kernel/ai/feature/settings/MemoryScreen.kt @@ -1,6 +1,5 @@ package com.kernel.ai.feature.settings -import androidx.compose.animation.core.animateFloatAsState import androidx.compose.foundation.ExperimentalFoundationApi import androidx.compose.foundation.clickable import androidx.compose.foundation.combinedClickable @@ -21,7 +20,6 @@ import androidx.compose.material.icons.automirrored.filled.ArrowBack import androidx.compose.material.icons.filled.Add import androidx.compose.material.icons.filled.Clear import androidx.compose.material.icons.filled.Delete -import androidx.compose.material.icons.filled.KeyboardArrowDown import androidx.compose.material.icons.filled.Search import androidx.compose.material3.AlertDialog import androidx.compose.material3.Badge @@ -45,17 +43,17 @@ import androidx.compose.material3.TopAppBar import androidx.compose.material3.rememberModalBottomSheetState import androidx.compose.runtime.Composable import androidx.compose.runtime.getValue -import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.remember +import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.setValue import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier -import androidx.compose.ui.draw.rotate import androidx.compose.ui.text.style.TextOverflow import androidx.compose.ui.tooling.preview.Preview import androidx.compose.ui.unit.dp import androidx.hilt.navigation.compose.hiltViewModel import androidx.lifecycle.compose.collectAsStateWithLifecycle +import com.kernel.ai.core.ui.CollapsibleSectionHeader import com.kernel.ai.core.memory.entity.CoreMemoryEntity import com.kernel.ai.core.memory.entity.EpisodicMemoryEntity import com.kernel.ai.core.memory.entity.KiwiMemoryEntity @@ -653,50 +651,6 @@ private fun SectionHeader(title: String) { ) } -@Composable -private fun CollapsibleSectionHeader( - title: String, - count: Int? = null, - isExpanded: Boolean, - onToggle: () -> Unit, -) { - val rotation by animateFloatAsState(targetValue = if (isExpanded) 180f else 0f, label = "chevron") - Row( - modifier = Modifier - .fillMaxWidth() - .clickable { onToggle() } - .padding(horizontal = 16.dp, vertical = 8.dp), - verticalAlignment = Alignment.CenterVertically, - horizontalArrangement = Arrangement.SpaceBetween, - ) { - Row( - verticalAlignment = Alignment.CenterVertically, - horizontalArrangement = Arrangement.spacedBy(8.dp), - ) { - Text( - text = title, - style = MaterialTheme.typography.labelMedium, - color = MaterialTheme.colorScheme.primary, - ) - if (count != null && count > 0) { - Badge(containerColor = MaterialTheme.colorScheme.secondaryContainer) { - Text( - text = count.toString(), - color = MaterialTheme.colorScheme.onSecondaryContainer, - style = MaterialTheme.typography.labelSmall, - ) - } - } - } - Icon( - imageVector = Icons.Default.KeyboardArrowDown, - contentDescription = if (isExpanded) "Collapse" else "Expand", - modifier = Modifier.rotate(rotation), - tint = MaterialTheme.colorScheme.primary, - ) - } -} - /** * Section header for Core Memories. In normal mode shows just the label. * In selection mode shows "Select All", "Delete Selected (N)", and "Cancel" actions. diff --git a/feature/settings/src/main/java/com/kernel/ai/feature/settings/ModelManagementScreen.kt b/feature/settings/src/main/java/com/kernel/ai/feature/settings/ModelManagementScreen.kt index 881ecd210..82089e017 100644 --- a/feature/settings/src/main/java/com/kernel/ai/feature/settings/ModelManagementScreen.kt +++ b/feature/settings/src/main/java/com/kernel/ai/feature/settings/ModelManagementScreen.kt @@ -3,7 +3,6 @@ package com.kernel.ai.feature.settings import androidx.compose.foundation.clickable import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Column -import androidx.compose.foundation.layout.PaddingValues import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.Spacer import androidx.compose.foundation.layout.fillMaxSize @@ -17,24 +16,17 @@ import androidx.compose.foundation.lazy.items import androidx.compose.foundation.lazy.rememberLazyListState import androidx.compose.material.icons.Icons import androidx.compose.material.icons.automirrored.filled.ArrowBack -import androidx.compose.material.icons.filled.AccountCircle -import androidx.compose.material.icons.filled.CheckCircle -import androidx.compose.material.icons.filled.Lock import androidx.compose.material3.Button -import androidx.compose.material3.ButtonDefaults import androidx.compose.material3.Card import androidx.compose.material3.CardDefaults import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.HorizontalDivider import androidx.compose.material3.Icon import androidx.compose.material3.IconButton -import androidx.compose.material3.LinearProgressIndicator import androidx.compose.material3.ListItem import androidx.compose.material3.MaterialTheme import androidx.compose.material3.RadioButton import androidx.compose.material3.Scaffold -import androidx.compose.material3.SuggestionChip -import androidx.compose.material3.SuggestionChipDefaults import androidx.compose.material3.Text import androidx.compose.material3.TextButton import androidx.compose.material3.TopAppBar @@ -46,7 +38,6 @@ import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier import androidx.core.net.toUri import androidx.browser.customtabs.CustomTabsIntent -import androidx.compose.ui.graphics.Color import androidx.compose.ui.text.font.FontWeight import androidx.compose.ui.unit.dp import androidx.hilt.navigation.compose.hiltViewModel @@ -54,9 +45,10 @@ import androidx.lifecycle.compose.collectAsStateWithLifecycle import com.kernel.ai.core.inference.PersonaMode import com.kernel.ai.core.inference.download.DownloadState import com.kernel.ai.core.inference.download.KernelModel +import com.kernel.ai.core.model.availability.ModelCard +import com.kernel.ai.core.model.availability.ModelAvailabilityState +import com.kernel.ai.core.model.availability.toAvailability -private val HfOrange = Color(0xFFFF9D00) -private const val EMBEDDING_GEMMA_LICENCE_URL = "https://huggingface.co/litert-community/embeddinggemma-300m" @OptIn(ExperimentalMaterial3Api::class) @Composable @@ -70,12 +62,10 @@ fun ModelManagementScreen( val listState = rememberLazyListState() // Scroll to "Conversation model" section when requested (e.g. from Settings "Preferred model" item). - // Wait until models are loaded so the item count is accurate. - val visibleModelCount = uiState.models.count { it.model != KernelModel.EMBEDDING_GEMMA_300M_SM8550 } + val visibleModelCount = uiState.models.size LaunchedEffect(scrollToConversationModel, visibleModelCount) { if (scrollToConversationModel && visibleModelCount > 0) { - // Layout: 0=storage, 1=HF account, 2=Models header, 3..3+N-1=model rows, 3+N=Conversation model header - listState.animateScrollToItem(index = 3 + visibleModelCount) + listState.animateScrollToItem(index = 2 + visibleModelCount) } } @@ -105,26 +95,6 @@ fun ModelManagementScreen( modifier = Modifier.padding(16.dp), ) } - - // ── HuggingFace account ─────────────────────────────────────────── - item { - Text( - text = "HuggingFace Account", - style = MaterialTheme.typography.labelMedium, - color = MaterialTheme.colorScheme.primary, - modifier = Modifier.padding(horizontal = 16.dp, vertical = 4.dp), - ) - HuggingFaceRow( - isAuthenticated = uiState.hfAuthenticated, - username = uiState.hfUsername, - onSignIn = { viewModel.startAuth() }, - onSignOut = { viewModel.signOut() }, - onViewLicence = { openInAppBrowser(context, EMBEDDING_GEMMA_LICENCE_URL) }, - ) - HorizontalDivider() - Spacer(modifier = Modifier.height(8.dp)) - } - // ── Model rows ──────────────────────────────────────────────────── item { Text( @@ -135,20 +105,54 @@ fun ModelManagementScreen( ) } - // Skip EMBEDDING_GEMMA_300M_SM8550 (disabled variant) - val visibleModels = uiState.models.filter { it.model != KernelModel.EMBEDDING_GEMMA_300M_SM8550 } + + // Skip EMBEDDING_GEMMA_300M_SM8550 (already filtered by isDeprecated in VM) + val visibleModels = uiState.models items(visibleModels) { rowState -> - ModelRow( - rowState = rowState, - isAuthenticated = uiState.hfAuthenticated, - onDownload = { viewModel.downloadModel(rowState.model) }, - onCancel = { viewModel.cancelDownload(rowState.model) }, - onUpdate = { viewModel.updateModel(rowState.model) }, - onDelete = { viewModel.deleteModel(rowState.model) }, - onViewLicence = { url -> openInAppBrowser(context, url) }, - onRetry = { viewModel.downloadModel(rowState.model) }, + val availabilityState = rowState.downloadState.toAvailability( + model = rowState.model, + hfAuth = uiState.hfAuthenticated, + source = rowState.downloadSource, ) - HorizontalDivider() + val canDelete = availabilityState is ModelAvailabilityState.Ready && + !rowState.model.isBundled && + rowState.model != uiState.preferredModel + ModelCard( + title = rowState.model.displayName, + description = "%.1f MB".format(rowState.model.approxSizeBytes / 1_000_000f), + state = availabilityState, + showLock = rowState.model.isGated && rowState.downloadState is DownloadState.NotDownloaded, + onPrimaryAction = { + when (val state = rowState.downloadState) { + is DownloadState.Downloading -> viewModel.cancelDownload(rowState.model) + is DownloadState.Downloaded -> viewModel.updateModel(rowState.model) + is DownloadState.NotDownloaded -> { + if (!uiState.hfAuthenticated && rowState.model.isGated) { + viewModel.startAuth() + } else { + viewModel.downloadModel(rowState.model) + } + } + is DownloadState.Error -> { + if (state.licenceRequired) { + rowState.model.licenceUrl?.let { url -> + CustomTabsIntent.Builder().build().launchUrl(context, url.toUri()) + } + } else { + viewModel.downloadModel(rowState.model) + } + } + } + }, + onSecondaryAction = if (canDelete) { + { viewModel.deleteModel(rowState.model) } + } else { + null + }, + secondaryActionLabel = if (canDelete) "Delete" else null, + modifier = Modifier.padding(horizontal = 16.dp, vertical = 4.dp), + ) + Spacer(modifier = Modifier.height(4.dp)) } // ── Preferred model section ─────────────────────────────────────── @@ -342,211 +346,6 @@ private fun StorageSummaryCard( } } -@Composable -private fun HuggingFaceRow( - isAuthenticated: Boolean, - username: String?, - onSignIn: () -> Unit, - onSignOut: () -> Unit, - onViewLicence: () -> Unit, - modifier: Modifier = Modifier, -) { - if (isAuthenticated) { - ListItem( - modifier = modifier.fillMaxWidth(), - headlineContent = { - Text(if (username != null) "@$username" else "Signed in") - }, - supportingContent = { - Column { - Text("Gated models unlocked") - TextButton(onClick = onViewLicence, contentPadding = PaddingValues(0.dp)) { - Text("View licence →", style = MaterialTheme.typography.bodySmall) - } - } - }, - leadingContent = { - Icon(Icons.Default.AccountCircle, contentDescription = null, tint = HfOrange) - }, - trailingContent = { - TextButton(onClick = onSignOut) { - Text("Sign out", color = MaterialTheme.colorScheme.error) - } - }, - ) - } else { - ListItem( - modifier = modifier.fillMaxWidth(), - headlineContent = { Text("Not signed in") }, - supportingContent = { - Column { - Text("Required to download gated Hugging Face models. Accept licence before downloading.") - TextButton(onClick = onViewLicence, contentPadding = PaddingValues(0.dp)) { - Text("View licence →", style = MaterialTheme.typography.bodySmall) - } - } - }, - leadingContent = { - Icon(Icons.Default.AccountCircle, contentDescription = null, tint = MaterialTheme.colorScheme.onSurfaceVariant) - }, - trailingContent = { - Button( - onClick = onSignIn, - colors = ButtonDefaults.buttonColors(containerColor = HfOrange), - ) { - Text("Sign in", color = Color.Black) - } - }, - ) - } -} - -@Composable -private fun ModelRow( - rowState: ModelRowState, - isAuthenticated: Boolean, - onDownload: () -> Unit, - onCancel: () -> Unit, - onUpdate: () -> Unit, - onDelete: () -> Unit, - onViewLicence: (String) -> Unit, - onRetry: () -> Unit, - modifier: Modifier = Modifier, -) { - val model = rowState.model - val state = rowState.downloadState - - ListItem( - modifier = modifier.fillMaxWidth(), - headlineContent = { - Row(verticalAlignment = Alignment.CenterVertically) { - Text(model.displayName) - if (model.isGated) { - Spacer(modifier = Modifier.width(4.dp)) - Icon( - Icons.Default.Lock, - contentDescription = "Gated", - modifier = Modifier.size(14.dp), - tint = HfOrange, - ) - } - Spacer(modifier = Modifier.width(6.dp)) - if (model.isRequired) { - SuggestionChip( - onClick = {}, - label = { Text("Required", style = MaterialTheme.typography.labelSmall) }, - colors = SuggestionChipDefaults.suggestionChipColors( - containerColor = MaterialTheme.colorScheme.primaryContainer, - ), - ) - } else { - SuggestionChip( - onClick = {}, - label = { Text("Optional", style = MaterialTheme.typography.labelSmall) }, - ) - } - } - }, - supportingContent = { - Column { - Text( - text = formatBytes(model.approxSizeBytes), - style = MaterialTheme.typography.bodySmall, - color = MaterialTheme.colorScheme.onSurfaceVariant, - ) - when (state) { - is DownloadState.Downloading -> { - Spacer(modifier = Modifier.height(4.dp)) - LinearProgressIndicator( - progress = { state.progress }, - modifier = Modifier.fillMaxWidth(), - ) - val pct = (state.progress * 100).toInt() - val mbps = state.bytesPerSecond / 1_000_000.0 - val etaSec = state.remainingMs / 1000 - Text( - text = buildString { - append("$pct%") - if (state.bytesPerSecond > 0) append(" · ${"%.1f".format(mbps)} MB/s") - if (etaSec > 0) append(" · ${etaSec}s remaining") - }, - style = MaterialTheme.typography.bodySmall, - ) - } - is DownloadState.Error -> { - Text( - text = state.message, - style = MaterialTheme.typography.bodySmall, - color = MaterialTheme.colorScheme.error, - ) - } - else -> Unit - } - } - }, - trailingContent = { - Row(verticalAlignment = Alignment.CenterVertically) { - when (state) { - is DownloadState.NotDownloaded -> { - val gatedBlocked = model.isGated && !isAuthenticated - TextButton( - onClick = onDownload, - enabled = !gatedBlocked, - ) { - Text("Download") - } - } - is DownloadState.Downloading -> { - TextButton(onClick = onCancel) { - Text("Cancel") - } - } - is DownloadState.Downloaded -> { - if (model.isBundled) { - SuggestionChip( - onClick = {}, - label = { Text("Built-in", style = MaterialTheme.typography.labelSmall) }, - colors = SuggestionChipDefaults.suggestionChipColors( - containerColor = MaterialTheme.colorScheme.secondaryContainer, - ), - ) - } else { - Icon( - Icons.Default.CheckCircle, - contentDescription = "Downloaded", - tint = Color(0xFF4CAF50), - modifier = Modifier.size(20.dp), - ) - Spacer(modifier = Modifier.width(4.dp)) - TextButton(onClick = onUpdate) { - Text("Update") - } - if (!model.isRequired) { - Spacer(modifier = Modifier.width(4.dp)) - TextButton(onClick = onDelete) { - Text("Delete", color = MaterialTheme.colorScheme.error) - } - } - } - } - is DownloadState.Error -> { - Column(horizontalAlignment = Alignment.End) { - val licenceUrl = model.licenceUrl - if (state.licenceRequired && licenceUrl != null) { - TextButton(onClick = { onViewLicence(licenceUrl) }) { - Text("Accept licence") - } - } - TextButton(onClick = onRetry) { - Text("Retry") - } - } - } - } - } - }, - ) -} private fun formatBytes(bytes: Long): String { return when { diff --git a/feature/settings/src/main/java/com/kernel/ai/feature/settings/ModelManagementViewModel.kt b/feature/settings/src/main/java/com/kernel/ai/feature/settings/ModelManagementViewModel.kt index 28e74fd32..a21419219 100644 --- a/feature/settings/src/main/java/com/kernel/ai/feature/settings/ModelManagementViewModel.kt +++ b/feature/settings/src/main/java/com/kernel/ai/feature/settings/ModelManagementViewModel.kt @@ -7,17 +7,26 @@ import androidx.lifecycle.viewModelScope import com.kernel.ai.core.inference.JandalPersona import com.kernel.ai.core.inference.PersonaMode import com.kernel.ai.core.inference.auth.HuggingFaceAuthRepository +import com.kernel.ai.core.inference.download.DownloadSource import com.kernel.ai.core.inference.download.DownloadState import com.kernel.ai.core.inference.download.KernelModel import com.kernel.ai.core.inference.download.ModelDownloadManager import com.kernel.ai.core.inference.download.localFile import com.kernel.ai.core.inference.prefs.ModelPreferences +import com.kernel.ai.core.model.availability.AvailabilitySummary +import com.kernel.ai.core.model.availability.GatedModelStatus +import com.kernel.ai.core.model.availability.GatedModelStatusRepository +import com.kernel.ai.core.model.availability.computeAvailabilitySummary import dagger.hilt.android.lifecycle.HiltViewModel import dagger.hilt.android.qualifiers.ApplicationContext import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.SharingStarted +import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.combine import kotlinx.coroutines.flow.stateIn +import kotlinx.coroutines.flow.debounce +import kotlinx.coroutines.flow.update import kotlinx.coroutines.launch import kotlinx.coroutines.withContext import java.io.IOException @@ -26,6 +35,7 @@ import javax.inject.Inject data class ModelRowState( val model: KernelModel, val downloadState: DownloadState, + val downloadSource: DownloadSource = DownloadSource.USER_INITIATED, ) data class ModelManagementUiState( @@ -36,6 +46,12 @@ data class ModelManagementUiState( val hfUsername: String? = null, val preferredModel: KernelModel? = null, val personaMode: PersonaMode = PersonaMode.HALF, + val availabilitySummary: AvailabilitySummary = AvailabilitySummary(total = 0), +) + +private data class StorageMetrics( + val used: Long = 0, + val free: Long = 0, ) @HiltViewModel @@ -44,29 +60,89 @@ class ModelManagementViewModel @Inject constructor( private val modelPreferences: ModelPreferences, private val authRepository: HuggingFaceAuthRepository, private val jandalPersona: JandalPersona, + private val gatedModelStatusRepository: GatedModelStatusRepository, @ApplicationContext private val context: Context, ) : ViewModel() { - val uiState = combine( + private val _storageMetrics = MutableStateFlow(StorageMetrics()) + + /** Per-model gated status map, collected from DataStore. */ + private val _gatedStatuses = MutableStateFlow>(emptyMap()) + + init { + viewModelScope.launch { + val gatedModels = KernelModel.entries.filter { + it.showInModelManagement && !it.isDeprecated && it.isGated + } + gatedModels.forEach { model -> + launch { + gatedModelStatusRepository.get(model).collect { status -> + _gatedStatuses.update { it.toMutableMap().apply { put(model, status) } } + } + } + } + } + // Compute storage metrics on IO dispatcher, driven by download-state changes + viewModelScope.launch(Dispatchers.IO) { + modelDownloadManager.downloadStates + .debounce(500) + .collect { + val used = calculateStorageUsed() + val free = calculateFreeSpace() + _storageMetrics.value = StorageMetrics(used = used, free = free) + } + } + } + + val uiState: StateFlow = combine( modelDownloadManager.downloadStates, + modelDownloadManager.downloadSources, authRepository.isAuthenticated, authRepository.username, modelPreferences.preferredConversationModel, jandalPersona.personaMode, - ) { downloadStates, hfAuthenticated, hfUsername, preferredModel, personaMode -> + _storageMetrics, + _gatedStatuses, + ) { array -> + @Suppress("UNCHECKED_CAST") + val downloadStates = array[0] as Map + @Suppress("UNCHECKED_CAST") + val downloadSources = array[1] as Map + val hfAuthenticated = array[2] as Boolean + @Suppress("UNCHECKED_CAST") + val hfUsername = array[3] as String? + @Suppress("UNCHECKED_CAST") + val preferredModel = array[4] as KernelModel? + val personaMode = array[5] as PersonaMode + val storage = array[6] as StorageMetrics + val gatedStatuses = array[7] as Map + + val filteredModels = KernelModel.entries.filter { + it.showInModelManagement && !it.isDeprecated + } + val models = filteredModels.map { model -> + ModelRowState( + model = model, + downloadState = downloadStates[model] ?: DownloadState.NotDownloaded, + downloadSource = downloadSources[model] ?: DownloadSource.USER_INITIATED, + ) + } + val summary = computeAvailabilitySummary( + models = filteredModels, + downloadStates = downloadStates, + hfAuth = hfAuthenticated, + downloadSources = downloadSources, + gatedStatuses = gatedStatuses, + ) ModelManagementUiState( - models = KernelModel.entries.filter { it.showInModelManagement }.map { model -> - ModelRowState( - model = model, - downloadState = downloadStates[model] ?: DownloadState.NotDownloaded, - ) - }, - totalStorageUsedBytes = calculateStorageUsed(), - freeSpaceBytes = calculateFreeSpace(), + models = models, + totalStorageUsedBytes = storage.used, + freeSpaceBytes = storage.free, hfAuthenticated = hfAuthenticated, hfUsername = hfUsername, preferredModel = preferredModel, personaMode = personaMode, + availabilitySummary = summary, ) }.stateIn( scope = viewModelScope, @@ -87,10 +163,16 @@ class ModelManagementViewModel @Inject constructor( } fun deleteModel(model: KernelModel) { - if (model.isRequired || model.isBundled) return + if (model.isBundled) return + // Never delete the currently selected conversation model + if (model == uiState.value.preferredModel) return + // Cancel any in-progress download before deleting the file + val currentState = modelDownloadManager.downloadStates.value[model] + if (currentState is DownloadState.Downloading) { + modelDownloadManager.cancelDownload(model) + } viewModelScope.launch(Dispatchers.IO) { model.localFile(context).delete() - // Also delete any stale .tmp resume file val tmpFile = java.io.File(model.localFile(context).absolutePath + ".tmp") if (tmpFile.exists()) tmpFile.delete() withContext(Dispatchers.Main) { diff --git a/feature/settings/src/main/java/com/kernel/ai/feature/settings/ModelSettingsScreen.kt b/feature/settings/src/main/java/com/kernel/ai/feature/settings/ModelSettingsScreen.kt index 8c0163c84..35a7c16b3 100644 --- a/feature/settings/src/main/java/com/kernel/ai/feature/settings/ModelSettingsScreen.kt +++ b/feature/settings/src/main/java/com/kernel/ai/feature/settings/ModelSettingsScreen.kt @@ -83,10 +83,11 @@ fun ModelSettingsScreen( .verticalScroll(rememberScrollState()), ) { uiState.e2bSettings?.let { settings -> - ModelCard( + ModelSettingsCard( modelName = "Gemma 4 E-2B", settings = settings, capabilities = KernelModel.GEMMA_4_E2B.capabilities, + state = uiState.e2bAvailability, onSettingsChanged = viewModel::updateE2bSettings, onReset = viewModel::resetE2bToDefaults, ) @@ -95,10 +96,11 @@ fun ModelSettingsScreen( HorizontalDivider(modifier = Modifier.padding(vertical = 8.dp)) uiState.e4bSettings?.let { settings -> - ModelCard( + ModelSettingsCard( modelName = "Gemma 4 E-4B", settings = settings, capabilities = KernelModel.GEMMA_4_E4B.capabilities, + state = uiState.e4bAvailability, onSettingsChanged = viewModel::updateE4bSettings, onReset = viewModel::resetE4bToDefaults, ) @@ -138,10 +140,11 @@ fun ModelSettingsScreen( } @Composable -private fun ModelCard( +private fun ModelSettingsCard( modelName: String, settings: ModelSettingsEntity, capabilities: ModelCapabilities, + state: com.kernel.ai.core.model.availability.ModelAvailabilityState? = null, onSettingsChanged: (ModelSettingsEntity) -> Unit, onReset: () -> Unit, ) { @@ -155,8 +158,11 @@ private fun ModelCard( Text( text = modelName, style = MaterialTheme.typography.titleMedium, - modifier = Modifier.padding(start = 8.dp), + modifier = Modifier.padding(start = 8.dp).weight(1f), ) + if (state != null) { + com.kernel.ai.core.model.availability.StateBadge(state = state) + } } Spacer(modifier = Modifier.height(16.dp)) @@ -384,7 +390,7 @@ private fun ModelSettingsScreenPreview() { topP = 0.95f, speculativeDecodingEnabled = false, ) - ModelCard( + ModelSettingsCard( modelName = "Gemma 4 E-2B", settings = sampleSettings, capabilities = KernelModel.GEMMA_4_E2B.capabilities, diff --git a/feature/settings/src/main/java/com/kernel/ai/feature/settings/ModelSettingsViewModel.kt b/feature/settings/src/main/java/com/kernel/ai/feature/settings/ModelSettingsViewModel.kt index f97202d95..33c0c7f2b 100644 --- a/feature/settings/src/main/java/com/kernel/ai/feature/settings/ModelSettingsViewModel.kt +++ b/feature/settings/src/main/java/com/kernel/ai/feature/settings/ModelSettingsViewModel.kt @@ -4,6 +4,13 @@ import android.util.Log import androidx.lifecycle.ViewModel import androidx.lifecycle.viewModelScope import com.kernel.ai.core.inference.download.KernelModel +import com.kernel.ai.core.inference.download.DownloadState +import com.kernel.ai.core.inference.download.DownloadSource +import com.kernel.ai.core.inference.download.ModelDownloadManager +import com.kernel.ai.core.model.availability.ModelAvailabilityState +import com.kernel.ai.core.model.availability.toAvailability +import kotlinx.coroutines.flow.collect +import kotlinx.coroutines.flow.combine import com.kernel.ai.core.memory.entity.ModelSettingsEntity import com.kernel.ai.core.memory.repository.ModelSettingsRepository import dagger.hilt.android.lifecycle.HiltViewModel @@ -16,9 +23,9 @@ import javax.inject.Inject @HiltViewModel class ModelSettingsViewModel @Inject constructor( + private val modelDownloadManager: ModelDownloadManager, private val modelSettingsRepository: ModelSettingsRepository, ) : ViewModel() { - data class ModelSettingsUiState( /** Current draft values shown in the UI. Not persisted until [saveSettings] is called. */ val e2bSettings: ModelSettingsEntity? = null, @@ -27,6 +34,8 @@ class ModelSettingsViewModel @Inject constructor( val persistedE2b: ModelSettingsEntity? = null, val persistedE4b: ModelSettingsEntity? = null, val isSaving: Boolean = false, + val e2bAvailability: ModelAvailabilityState? = null, + val e4bAvailability: ModelAvailabilityState? = null, ) { val hasUnsavedChanges: Boolean get() = e2bSettings != persistedE2b || e4bSettings != persistedE4b @@ -37,6 +46,31 @@ class ModelSettingsViewModel @Inject constructor( init { loadSettings() + viewModelScope.launch { + combine( + modelDownloadManager.downloadStates, + modelDownloadManager.downloadSources, + ) { states, sources -> + _uiState.update { current -> + current.copy( + e2bAvailability = states[KernelModel.GEMMA_4_E2B] + ?.toAvailability( + KernelModel.GEMMA_4_E2B, + hfAuth = false, + source = sources[KernelModel.GEMMA_4_E2B] + ?: DownloadSource.USER_INITIATED, + ), + e4bAvailability = states[KernelModel.GEMMA_4_E4B] + ?.toAvailability( + KernelModel.GEMMA_4_E4B, + hfAuth = false, + source = sources[KernelModel.GEMMA_4_E4B] + ?: DownloadSource.USER_INITIATED, + ), + ) + } + }.collect() + } } private fun loadSettings() { diff --git a/feature/settings/src/main/java/com/kernel/ai/feature/settings/SettingsScreen.kt b/feature/settings/src/main/java/com/kernel/ai/feature/settings/SettingsScreen.kt index ac82b3754..35f8de12e 100644 --- a/feature/settings/src/main/java/com/kernel/ai/feature/settings/SettingsScreen.kt +++ b/feature/settings/src/main/java/com/kernel/ai/feature/settings/SettingsScreen.kt @@ -149,19 +149,17 @@ fun SettingsScreen( ) HorizontalDivider() - // HuggingFace account grouped with models — needed to unlock gated HF models - Text( - text = "HuggingFace Account", - style = MaterialTheme.typography.labelMedium, - color = MaterialTheme.colorScheme.primary, - modifier = Modifier.padding(horizontal = 16.dp, vertical = 4.dp), - ) - HuggingFaceAccountRow( - isAuthenticated = uiState.hfAuthenticated, - username = uiState.hfUsername, - onSignIn = { viewModel.startAuth() }, - onSignOut = { viewModel.signOutHuggingFace() }, - onViewLicence = { openInAppBrowser(context, "https://huggingface.co/litert-community/embeddinggemma-300m") }, + // ── Model availability (tappable) ──────────────────────────────── + ListItem( + modifier = Modifier + .fillMaxWidth() + .clickable { onNavigateToModelManagement(false) }, + headlineContent = { Text("Model availability") }, + supportingContent = { + Text(uiState.modelAvailabilitySummary.displaySummary) + }, + leadingContent = { Icon(Icons.Default.Download, contentDescription = null) }, + trailingContent = { Icon(Icons.Default.ChevronRight, contentDescription = null) }, ) HorizontalDivider() diff --git a/feature/settings/src/main/java/com/kernel/ai/feature/settings/SettingsViewModel.kt b/feature/settings/src/main/java/com/kernel/ai/feature/settings/SettingsViewModel.kt index 89a482aae..f6ccb8aef 100644 --- a/feature/settings/src/main/java/com/kernel/ai/feature/settings/SettingsViewModel.kt +++ b/feature/settings/src/main/java/com/kernel/ai/feature/settings/SettingsViewModel.kt @@ -7,17 +7,22 @@ import com.kernel.ai.core.inference.auth.HuggingFaceAuthRepository import com.kernel.ai.core.inference.download.DownloadState import com.kernel.ai.core.inference.download.KernelModel import com.kernel.ai.core.inference.download.ModelDownloadManager +import com.kernel.ai.core.model.availability.AvailabilitySummary +import com.kernel.ai.core.model.availability.GatedModelStatus +import com.kernel.ai.core.model.availability.GatedModelStatusRepository +import com.kernel.ai.core.model.availability.computeAvailabilitySummary import com.kernel.ai.core.inference.hardware.HardwareProfileDetector import com.kernel.ai.core.inference.prefs.ModelPreferences import dagger.hilt.android.lifecycle.HiltViewModel -import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.flow.MutableSharedFlow -import kotlinx.coroutines.flow.SharingStarted +import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.SharedFlow +import kotlinx.coroutines.flow.SharingStarted import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.asSharedFlow import kotlinx.coroutines.flow.combine import kotlinx.coroutines.flow.stateIn +import kotlinx.coroutines.flow.update import kotlinx.coroutines.launch import java.io.IOException import javax.inject.Inject @@ -28,6 +33,7 @@ class SettingsViewModel @Inject constructor( private val modelDownloadManager: ModelDownloadManager, private val modelPreferences: ModelPreferences, private val authRepository: HuggingFaceAuthRepository, + private val gatedModelStatusRepository: GatedModelStatusRepository, ) : ViewModel() { data class SettingsUiState( @@ -41,14 +47,52 @@ class SettingsViewModel @Inject constructor( val hfAuthenticated: Boolean = false, /** HuggingFace username from OIDC id_token, or null. */ val hfUsername: String? = null, + val modelAvailabilitySummary: AvailabilitySummary = AvailabilitySummary(total = 0), ) + private val _gatedStatuses = MutableStateFlow>(emptyMap()) + + init { + viewModelScope.launch { + val gatedModels = KernelModel.entries.filter { + it.showInModelManagement && !it.isDeprecated && it.isGated + } + gatedModels.forEach { model -> + launch { + gatedModelStatusRepository.get(model).collect { status -> + _gatedStatuses.update { it.toMutableMap().apply { put(model, status) } } + } + } + } + } + // Forward authResult outcomes so the Settings screen can surface sign-in feedback. + viewModelScope.launch { + authRepository.authResult.collect { result -> + result.onSuccess { _saveSuccess.tryEmit("Signed in to HuggingFace ✓") } + result.onFailure { e -> _saveError.tryEmit("Sign-in failed: ${e.message}") } + } + } + } + val uiState: StateFlow = combine( modelPreferences.preferredConversationModel, modelDownloadManager.downloadStates, + modelDownloadManager.downloadSources, authRepository.isAuthenticated, authRepository.username, - ) { preferredModel, downloadStates, hfAuthenticated, hfUsername -> + _gatedStatuses, + ) { array -> + @Suppress("UNCHECKED_CAST") + val preferredModel = array[0] as KernelModel? + @Suppress("UNCHECKED_CAST") + val downloadStates = array[1] as Map + @Suppress("UNCHECKED_CAST") + val downloadSources = array[2] as Map + val hfAuthenticated = array[3] as Boolean + @Suppress("UNCHECKED_CAST") + val hfUsername = array[4] as String? + @Suppress("UNCHECKED_CAST") + val gatedStatuses = array[5] as Map val profile = hardwareProfileDetector.profile val e4bDownloaded = downloadStates[KernelModel.GEMMA_4_E4B] is DownloadState.Downloaded val e2bDownloaded = downloadStates[KernelModel.GEMMA_4_E2B] is DownloadState.Downloaded @@ -64,6 +108,14 @@ class SettingsViewModel @Inject constructor( } } + val summary = computeAvailabilitySummary( + models = KernelModel.entries.filter { it.showInModelManagement && !it.isDeprecated }, + downloadStates = downloadStates, + hfAuth = hfAuthenticated, + downloadSources = downloadSources, + gatedStatuses = gatedStatuses, + ) + SettingsUiState( activeModelLabel = activeModel.displayName, activeBackend = profile.recommendedBackend.name, @@ -73,6 +125,7 @@ class SettingsViewModel @Inject constructor( e4bDownloaded = e4bDownloaded, hfAuthenticated = hfAuthenticated, hfUsername = hfUsername, + modelAvailabilitySummary = summary, ) }.stateIn( scope = viewModelScope, @@ -86,16 +139,6 @@ class SettingsViewModel @Inject constructor( private val _saveSuccess = MutableSharedFlow(extraBufferCapacity = 1) val saveSuccess: SharedFlow = _saveSuccess.asSharedFlow() - init { - // Forward authResult outcomes so the Settings screen can surface sign-in feedback. - viewModelScope.launch { - authRepository.authResult.collect { result -> - result.onSuccess { _saveSuccess.tryEmit("Signed in to HuggingFace ✓") } - result.onFailure { e -> _saveError.tryEmit("Sign-in failed: ${e.message}") } - } - } - } - fun setPreferredModel(model: KernelModel?) { viewModelScope.launch { val current = uiState.value.preferredModel diff --git a/feature/settings/src/main/java/com/kernel/ai/feature/settings/VoiceScreen.kt b/feature/settings/src/main/java/com/kernel/ai/feature/settings/VoiceScreen.kt index 25dfd4f6b..84382412f 100644 --- a/feature/settings/src/main/java/com/kernel/ai/feature/settings/VoiceScreen.kt +++ b/feature/settings/src/main/java/com/kernel/ai/feature/settings/VoiceScreen.kt @@ -28,7 +28,6 @@ import androidx.compose.material3.FilterChip import androidx.compose.material3.HorizontalDivider import androidx.compose.material3.Icon import androidx.compose.material3.IconButton -import androidx.compose.material3.LinearProgressIndicator import androidx.compose.material3.ListItem import androidx.compose.material3.MaterialTheme import androidx.compose.material3.OutlinedTextField @@ -39,6 +38,8 @@ import androidx.compose.material3.Switch import androidx.compose.material3.Tab import androidx.compose.material3.TabRow import androidx.compose.material3.Text +import androidx.compose.material3.Button +import androidx.compose.material3.OutlinedButton import androidx.compose.material3.TextButton import androidx.compose.material3.TopAppBar import androidx.compose.runtime.Composable @@ -69,7 +70,10 @@ import com.kernel.ai.core.voice.VctkSpeakerMetadata import com.kernel.ai.core.voice.VoiceInputEngine import com.kernel.ai.core.voice.VoiceOutputEngine import com.kernel.ai.core.voice.VoicePackDownloadState +import com.kernel.ai.core.model.availability.ModelAvailabilityState +import com.kernel.ai.core.model.availability.ModelCardCompact import kotlin.math.roundToInt +import com.kernel.ai.core.model.availability.UnavailableReason import android.content.Intent import android.os.Build import android.provider.Settings @@ -90,6 +94,7 @@ import androidx.lifecycle.compose.LocalLifecycleOwner @Composable fun VoiceScreen( onBack: () -> Unit, + onNavigateToModelManagement: () -> Unit = {}, viewModel: VoiceViewModel = hiltViewModel(), ) { val context = LocalContext.current @@ -181,19 +186,19 @@ fun VoiceScreen( onSherpaGainChanged = viewModel::setSherpaGain, onAutoSpeakChanged = viewModel::setAutoSpeak, onMaxSpokenSentencesChanged = viewModel::setMaxSpokenSentences, - onDownloadVoice = viewModel::downloadSherpaVoice, - onCancelVoiceDownload = viewModel::cancelSherpaVoiceDownload, - onDeleteVoice = viewModel::deleteSherpaVoice, onActiveSpeakerIdChanged = viewModel::setActiveSpeakerId, onKokoroVoiceSelected = viewModel::setKokoroVoice, - onDownloadKokoroVoice = viewModel::downloadKokoroVoice, - onCancelKokoroVoiceDownload = viewModel::cancelKokoroVoiceDownload, - onDeleteKokoroVoice = viewModel::deleteKokoroVoice, onKokoroActiveSpeakerIdChanged = viewModel::setKokoroActiveSpeakerId, + onNavigateToModelManagement = onNavigateToModelManagement, onDownloadSherpaStt = viewModel::downloadSherpaStt, - onCancelSherpaSttDownload = viewModel::cancelSherpaSttDownload, + onCancelSherpaStt = viewModel::cancelSherpaSttDownload, onDeleteSherpaStt = viewModel::deleteSherpaStt, - onViewSherpaSttLicence = { url -> openInAppBrowser(context, url) }, + onDownloadSherpaVoice = viewModel::downloadSherpaVoice, + onCancelSherpaVoice = viewModel::cancelSherpaVoiceDownload, + onDeleteSherpaVoice = viewModel::deleteSherpaVoice, + onDownloadKokoroVoice = viewModel::downloadKokoroVoice, + onCancelKokoroVoice = viewModel::cancelKokoroVoiceDownload, + onDeleteKokoroVoice = viewModel::deleteKokoroVoice, ) } @@ -215,19 +220,19 @@ private fun VoiceScreenContent( onSherpaGainChanged: (Float) -> Unit, onAutoSpeakChanged: (Boolean) -> Unit, onMaxSpokenSentencesChanged: (Int) -> Unit, - onDownloadVoice: (SherpaPiperVoice) -> Unit, - onCancelVoiceDownload: (SherpaPiperVoice) -> Unit, - onDeleteVoice: (SherpaPiperVoice) -> Unit, onActiveSpeakerIdChanged: (Int) -> Unit, onKokoroVoiceSelected: (SherpaKokoroVoice) -> Unit, - onDownloadKokoroVoice: (SherpaKokoroVoice) -> Unit, - onCancelKokoroVoiceDownload: (SherpaKokoroVoice) -> Unit, - onDeleteKokoroVoice: (SherpaKokoroVoice) -> Unit, onKokoroActiveSpeakerIdChanged: (Int) -> Unit, - onDownloadSherpaStt: (VoiceInputEngine) -> Unit, - onCancelSherpaSttDownload: (VoiceInputEngine) -> Unit, - onDeleteSherpaStt: (VoiceInputEngine) -> Unit, - onViewSherpaSttLicence: (String) -> Unit, + onNavigateToModelManagement: () -> Unit, + onDownloadSherpaStt: (VoiceInputEngine) -> Unit = { _ -> }, + onCancelSherpaStt: (VoiceInputEngine) -> Unit = { _ -> }, + onDeleteSherpaStt: (VoiceInputEngine) -> Unit = { _ -> }, + onDownloadSherpaVoice: (SherpaPiperVoice) -> Unit = { _ -> }, + onCancelSherpaVoice: (SherpaPiperVoice) -> Unit = { _ -> }, + onDeleteSherpaVoice: (SherpaPiperVoice) -> Unit = { _ -> }, + onDownloadKokoroVoice: (SherpaKokoroVoice) -> Unit = { _ -> }, + onCancelKokoroVoice: (SherpaKokoroVoice) -> Unit = { _ -> }, + onDeleteKokoroVoice: (SherpaKokoroVoice) -> Unit = { _ -> }, ) { val context = LocalContext.current Scaffold( @@ -414,24 +419,56 @@ private fun VoiceScreenContent( if (engine.isSherpaFamily) { val state = sttState if (state != null && (!state.isDownloaded || uiState.selectedInputEngine == engine)) { - SherpaOnnxSttDownloadCard( - isDownloaded = state.isDownloaded, - isDownloading = state.isDownloading, - progress = state.progress, - issue = state.issue, - modelSubtitle = when (engine) { + ModelCardCompact( + title = engine.displayName, + description = when (engine) { VoiceInputEngine.SherpaZipformer -> SherpaSttModelSpec.ZIPFORMER.subtitle VoiceInputEngine.SherpaSenseVoice -> SherpaSttModelSpec.SENSE_VOICE.subtitle VoiceInputEngine.SherpaWhisper -> SherpaSttModelSpec.WHISPER.subtitle VoiceInputEngine.SherpaParaformer -> SherpaSttModelSpec.PARAFORMER.subtitle else -> "" }, - onDownload = { onDownloadSherpaStt(engine) }, - onCancel = { onCancelSherpaSttDownload(engine) }, - onDelete = { onDeleteSherpaStt(engine) }, - onViewLicence = onViewSherpaSttLicence, + state = uiState.sherpaSttAvailability[engine] + ?: ModelAvailabilityState.Unavailable(UnavailableReason.NotBundled), modifier = Modifier.padding(horizontal = 16.dp, vertical = 4.dp), ) + val sttState = state + if (sttState != null) { + when { + sttState.isDownloading -> { + OutlinedButton( + modifier = Modifier.padding(horizontal = 16.dp, vertical = 2.dp), + onClick = { onCancelSherpaStt(engine) }, + ) { + Text("Cancel") + } + } + sttState.isDownloaded -> { + OutlinedButton( + modifier = Modifier.padding(horizontal = 16.dp, vertical = 2.dp), + onClick = { onDeleteSherpaStt(engine) }, + ) { + Text("Delete") + } + } + sttState.issue != null -> { + OutlinedButton( + modifier = Modifier.padding(horizontal = 16.dp, vertical = 2.dp), + onClick = { onDownloadSherpaStt(engine) }, + ) { + Text("Retry") + } + } + else -> { + Button( + modifier = Modifier.padding(horizontal = 16.dp, vertical = 2.dp), + onClick = { onDownloadSherpaStt(engine) }, + ) { + Text("Download") + } + } + } + } } } } @@ -674,16 +711,59 @@ private fun VoiceScreenContent( val isMultiSpeaker = voiceRow.voice == SherpaPiperVoice.VctkMedium || voiceRow.voice == SherpaPiperVoice.SemaineMedium - SherpaVoiceRow( - rowState = voiceRow, - isSelected = isSelected, - onSelect = { onSherpaVoiceSelected(voiceRow.voice) }, - onDownload = { onDownloadVoice(voiceRow.voice) }, - onCancel = { onCancelVoiceDownload(voiceRow.voice) }, - onDelete = { onDeleteVoice(voiceRow.voice) }, - ) + Row( + modifier = Modifier.fillMaxWidth(), + verticalAlignment = Alignment.CenterVertically, + ) { + ModelCardCompact( + title = voiceRow.voice.displayName, + description = voiceRow.voice.description, + state = uiState.sherpaVoiceAvailability[voiceRow.voice] + ?: ModelAvailabilityState.Unavailable(UnavailableReason.NotBundled), + modifier = Modifier.weight(1f), + ) + RadioButton( + selected = isSelected, + onClick = { if (isDownloaded) onSherpaVoiceSelected(voiceRow.voice) }, + enabled = isDownloaded, + ) + } + val voiceRowState = voiceRow.downloadState + when (voiceRowState) { + is VoicePackDownloadState.NotDownloaded -> { + Button( + modifier = Modifier.padding(horizontal = 16.dp, vertical = 2.dp), + onClick = { onDownloadSherpaVoice(voiceRow.voice) }, + ) { + Text("Download") + } + } + is VoicePackDownloadState.Downloading -> { + OutlinedButton( + modifier = Modifier.padding(horizontal = 16.dp, vertical = 2.dp), + onClick = { onCancelSherpaVoice(voiceRow.voice) }, + ) { + Text("Cancel") + } + } + is VoicePackDownloadState.Downloaded -> { + OutlinedButton( + modifier = Modifier.padding(horizontal = 16.dp, vertical = 2.dp), + onClick = { onDeleteSherpaVoice(voiceRow.voice) }, + ) { + Text("Delete") + } + } + is VoicePackDownloadState.Error -> { + OutlinedButton( + modifier = Modifier.padding(horizontal = 16.dp, vertical = 2.dp), + onClick = { onDownloadSherpaVoice(voiceRow.voice) }, + ) { + Text("Retry") + } + } + } HorizontalDivider() - if (isSelected && isDownloaded && isMultiSpeaker) { Row( modifier = Modifier @@ -729,6 +809,12 @@ private fun VoiceScreenContent( } } } + TextButton( + onClick = onNavigateToModelManagement, + modifier = Modifier.padding(horizontal = 16.dp, vertical = 4.dp), + ) { + Text("Manage voice models") + } } if (uiState.selectedOutputEngine == VoiceOutputEngine.KokoroExperimental) { @@ -802,14 +888,58 @@ private fun VoiceScreenContent( val isSelected = uiState.selectedKokoroVoice == voiceRow.voice val isDownloaded = voiceRow.downloadState is VoicePackDownloadState.Downloaded - KokoroVoiceRow( - rowState = voiceRow, - isSelected = isSelected, - onSelect = { onKokoroVoiceSelected(voiceRow.voice) }, - onDownload = { onDownloadKokoroVoice(voiceRow.voice) }, - onCancel = { onCancelKokoroVoiceDownload(voiceRow.voice) }, - onDelete = { onDeleteKokoroVoice(voiceRow.voice) }, - ) + Row( + modifier = Modifier.fillMaxWidth(), + verticalAlignment = Alignment.CenterVertically, + ) { + ModelCardCompact( + title = voiceRow.voice.displayName, + description = voiceRow.voice.description, + state = uiState.kokoroVoiceAvailability[voiceRow.voice] + ?: ModelAvailabilityState.Unavailable(UnavailableReason.NotBundled), + modifier = Modifier.weight(1f), + ) + RadioButton( + selected = isSelected, + onClick = { if (isDownloaded) onKokoroVoiceSelected(voiceRow.voice) }, + enabled = isDownloaded, + ) + } + val kokoroRowState = voiceRow.downloadState + when (kokoroRowState) { + is VoicePackDownloadState.NotDownloaded -> { + Button( + modifier = Modifier.padding(horizontal = 16.dp, vertical = 2.dp), + onClick = { onDownloadKokoroVoice(voiceRow.voice) }, + ) { + Text("Download") + } + } + is VoicePackDownloadState.Downloading -> { + OutlinedButton( + modifier = Modifier.padding(horizontal = 16.dp, vertical = 2.dp), + onClick = { onCancelKokoroVoice(voiceRow.voice) }, + ) { + Text("Cancel") + } + } + is VoicePackDownloadState.Downloaded -> { + OutlinedButton( + modifier = Modifier.padding(horizontal = 16.dp, vertical = 2.dp), + onClick = { onDeleteKokoroVoice(voiceRow.voice) }, + ) { + Text("Delete") + } + } + is VoicePackDownloadState.Error -> { + OutlinedButton( + modifier = Modifier.padding(horizontal = 16.dp, vertical = 2.dp), + onClick = { onDownloadKokoroVoice(voiceRow.voice) }, + ) { + Text("Retry") + } + } + } HorizontalDivider() if (isSelected && isDownloaded && voiceRow.voice.speakerCount > 1) { @@ -820,6 +950,12 @@ private fun VoiceScreenContent( ) } } + TextButton( + onClick = onNavigateToModelManagement, + modifier = Modifier.padding(horizontal = 16.dp, vertical = 4.dp), + ) { + Text("Manage voice models") + } } } @@ -944,255 +1080,7 @@ private fun SemaineSpeakerSelector( } } -/** - * A single Sherpa Piper voice row with download/cancel/delete controls and progress indicator. - * Mirrors [ModelRow] in [ModelManagementScreen] for visual consistency. - */ -@Composable -private fun SherpaVoiceRow( - rowState: SherpaVoiceRowUiState, - isSelected: Boolean, - onSelect: () -> Unit, - onDownload: () -> Unit, - onCancel: () -> Unit, - onDelete: () -> Unit, - modifier: Modifier = Modifier, -) { - val voice = rowState.voice - val state = rowState.downloadState - val isDownloaded = state is VoicePackDownloadState.Downloaded - ListItem( - modifier = modifier.fillMaxWidth(), - headlineContent = { - Column(verticalArrangement = Arrangement.spacedBy(2.dp)) { - Text(voice.displayName) - Text( - text = when { - isDownloaded && isSelected -> "Selected voice" - isDownloaded -> "Downloaded and ready" - state is VoicePackDownloadState.Downloading -> "Downloading" - state is VoicePackDownloadState.Error -> "Download failed" - else -> "Not downloaded" - }, - style = MaterialTheme.typography.bodySmall, - color = when { - isDownloaded && isSelected -> MaterialTheme.colorScheme.primary - isDownloaded -> Color(0xFF2E7D32) - state is VoicePackDownloadState.Error -> MaterialTheme.colorScheme.error - else -> MaterialTheme.colorScheme.onSurfaceVariant - }, - ) - } - }, - supportingContent = { - Column { - Text( - text = voice.description, - style = MaterialTheme.typography.bodySmall, - color = MaterialTheme.colorScheme.onSurfaceVariant, - ) - Text( - text = formatBytes(voice.approxDownloadBytes), - style = MaterialTheme.typography.bodySmall, - color = MaterialTheme.colorScheme.onSurfaceVariant, - ) - if (!isDownloaded) { - Text( - text = when (state) { - is VoicePackDownloadState.Downloading -> "Selectable after download completes" - is VoicePackDownloadState.Error -> "Retry download before selecting this voice" - else -> "Download to make this voice selectable" - }, - style = MaterialTheme.typography.bodySmall, - color = MaterialTheme.colorScheme.onSurfaceVariant, - ) - } - when (state) { - is VoicePackDownloadState.Downloading -> { - Spacer(modifier = Modifier.height(4.dp)) - LinearProgressIndicator( - progress = { state.progress }, - modifier = Modifier.fillMaxWidth(), - ) - val pct = (state.progress * 100).toInt() - val mbps = state.bytesPerSecond / 1_000_000.0 - val etaSec = state.remainingMs / 1000 - Text( - text = buildString { - if (pct >= 90) append("Extracting…") - else { - append("$pct%") - if (state.bytesPerSecond > 0) append(" · ${"%.1f".format(mbps)} MB/s") - if (etaSec > 0) append(" · ${etaSec}s remaining") - } - }, - style = MaterialTheme.typography.bodySmall, - ) - } - is VoicePackDownloadState.Error -> { - Text( - text = state.message, - style = MaterialTheme.typography.bodySmall, - color = MaterialTheme.colorScheme.error, - ) - } - else -> Unit - } - } - }, - trailingContent = { - Row(verticalAlignment = Alignment.CenterVertically) { - when (state) { - is VoicePackDownloadState.NotDownloaded, is VoicePackDownloadState.Error -> { - TextButton(onClick = onDownload) { Text("Download") } - } - is VoicePackDownloadState.Downloading -> { - TextButton(onClick = onCancel) { Text("Cancel") } - } - is VoicePackDownloadState.Downloaded -> { - Icon( - Icons.Default.CheckCircle, - contentDescription = "Downloaded", - tint = Color(0xFF4CAF50), - modifier = Modifier.size(20.dp), - ) - Spacer(modifier = Modifier.width(4.dp)) - RadioButton(selected = isSelected, onClick = onSelect) - TextButton(onClick = onDelete) { - Text("Delete", color = MaterialTheme.colorScheme.error) - } - } - } - } - }, - ) -} - -/** - * A single Kokoro voice row with download/cancel/delete controls and progress indicator. - * Mirrors [SherpaVoiceRow] for visual consistency. - */ -@Composable -private fun KokoroVoiceRow( - rowState: KokoroVoiceRowUiState, - isSelected: Boolean, - onSelect: () -> Unit, - onDownload: () -> Unit, - onCancel: () -> Unit, - onDelete: () -> Unit, - modifier: Modifier = Modifier, -) { - val voice = rowState.voice - val state = rowState.downloadState - val isDownloaded = state is VoicePackDownloadState.Downloaded - - ListItem( - modifier = modifier.fillMaxWidth(), - headlineContent = { - Column(verticalArrangement = Arrangement.spacedBy(2.dp)) { - Text(voice.displayName) - Text( - text = when { - isDownloaded && isSelected -> "Selected voice" - isDownloaded -> "Downloaded and ready" - state is VoicePackDownloadState.Downloading -> "Downloading" - state is VoicePackDownloadState.Error -> "Download failed" - else -> "Not downloaded" - }, - style = MaterialTheme.typography.bodySmall, - color = when { - isDownloaded && isSelected -> MaterialTheme.colorScheme.primary - isDownloaded -> Color(0xFF2E7D32) - state is VoicePackDownloadState.Error -> MaterialTheme.colorScheme.error - else -> MaterialTheme.colorScheme.onSurfaceVariant - }, - ) - } - }, - supportingContent = { - Column { - Text( - text = voice.description, - style = MaterialTheme.typography.bodySmall, - color = MaterialTheme.colorScheme.onSurfaceVariant, - ) - Text( - text = formatBytes(voice.approxDownloadBytes), - style = MaterialTheme.typography.bodySmall, - color = MaterialTheme.colorScheme.onSurfaceVariant, - ) - if (!isDownloaded) { - Text( - text = when (state) { - is VoicePackDownloadState.Downloading -> "Selectable after download completes" - is VoicePackDownloadState.Error -> "Retry download before selecting this voice" - else -> "Download to make this voice selectable" - }, - style = MaterialTheme.typography.bodySmall, - color = MaterialTheme.colorScheme.onSurfaceVariant, - ) - } - when (state) { - is VoicePackDownloadState.Downloading -> { - Spacer(modifier = Modifier.height(4.dp)) - LinearProgressIndicator( - progress = { state.progress }, - modifier = Modifier.fillMaxWidth(), - ) - val pct = (state.progress * 100).toInt() - val mbps = state.bytesPerSecond / 1_000_000.0 - val etaSec = state.remainingMs / 1000 - Text( - text = buildString { - if (pct >= 90) append("Extracting…") - else { - append("$pct%") - if (state.bytesPerSecond > 0) append(" · ${"%.1f".format(mbps)} MB/s") - if (etaSec > 0) append(" · ${etaSec}s remaining") - } - }, - style = MaterialTheme.typography.bodySmall, - ) - } - is VoicePackDownloadState.Error -> { - Text( - text = state.message, - style = MaterialTheme.typography.bodySmall, - color = MaterialTheme.colorScheme.error, - ) - } - else -> Unit - } - } - }, - trailingContent = { - Row(verticalAlignment = Alignment.CenterVertically) { - when (state) { - is VoicePackDownloadState.NotDownloaded, is VoicePackDownloadState.Error -> { - TextButton(onClick = onDownload) { Text("Download") } - } - is VoicePackDownloadState.Downloading -> { - TextButton(onClick = onCancel) { Text("Cancel") } - } - is VoicePackDownloadState.Downloaded -> { - Icon( - Icons.Default.CheckCircle, - contentDescription = "Downloaded", - tint = Color(0xFF4CAF50), - modifier = Modifier.size(20.dp), - ) - Spacer(modifier = Modifier.width(4.dp)) - RadioButton(selected = isSelected, onClick = onSelect) - TextButton(onClick = onDelete) { - Text("Delete", color = MaterialTheme.colorScheme.error) - } - } - } - } - }, - ) -} /** * Speaker selector for Kokoro multi-speaker model (103 speakers, sid 0–102). @@ -1378,105 +1266,6 @@ private fun VoiceInfoCard( } -/** - * Inline card shown under a Sherpa-ONNX STT engine row when the engine is selected. - * Mirrors the pattern of [SherpaVoiceRow] / [KokoroVoiceRow] but groups the required - * model files as a single logical unit. - */ -@Composable -private fun SherpaOnnxSttDownloadCard( - isDownloaded: Boolean, - isDownloading: Boolean, - progress: Float, - issue: SherpaSttDownloadIssue?, - modelSubtitle: String, - onDownload: () -> Unit, - onCancel: () -> Unit, - onDelete: () -> Unit, - onViewLicence: (String) -> Unit, - modifier: Modifier = Modifier, -) { - Card( - modifier = modifier.fillMaxWidth(), - colors = CardDefaults.cardColors( - containerColor = if (isDownloaded) - MaterialTheme.colorScheme.primaryContainer - else - MaterialTheme.colorScheme.surfaceVariant, - ), - ) { - Column( - modifier = Modifier - .fillMaxWidth() - .padding(12.dp), - verticalArrangement = Arrangement.spacedBy(8.dp), - ) { - Row( - modifier = Modifier.fillMaxWidth(), - horizontalArrangement = Arrangement.SpaceBetween, - verticalAlignment = Alignment.CenterVertically, - ) { - Column(modifier = Modifier.weight(1f)) { - Text( - text = if (isDownloaded) "STT model ready" else "STT model required", - style = MaterialTheme.typography.labelMedium, - color = if (isDownloaded) - MaterialTheme.colorScheme.onPrimaryContainer - else - MaterialTheme.colorScheme.onSurfaceVariant, - ) - if (!isDownloaded && !isDownloading) { - Text( - text = modelSubtitle, - style = MaterialTheme.typography.bodySmall, - color = MaterialTheme.colorScheme.onSurfaceVariant, - ) - } - if (issue != null) { - Text( - text = issue.message, - style = MaterialTheme.typography.bodySmall, - color = MaterialTheme.colorScheme.error, - ) - } - } - when { - isDownloaded -> Row( - horizontalArrangement = Arrangement.spacedBy(4.dp), - verticalAlignment = Alignment.CenterVertically, - ) { - Icon( - imageVector = Icons.Filled.CheckCircle, - contentDescription = null, - tint = MaterialTheme.colorScheme.primary, - ) - TextButton(onClick = onDelete) { Text("Delete") } - } - isDownloading -> TextButton(onClick = onCancel) { Text("Cancel") } - issue?.licenceRequired == true && issue.licenceUrl != null -> Row( - horizontalArrangement = Arrangement.spacedBy(4.dp), - verticalAlignment = Alignment.CenterVertically, - ) { - TextButton(onClick = { onViewLicence(issue.licenceUrl) }) { Text("Accept licence") } - TextButton(onClick = onDownload) { Text("Retry") } - } - else -> TextButton(onClick = onDownload) { Text("Download") } - } - } - if (isDownloading) { - LinearProgressIndicator( - progress = { progress }, - modifier = Modifier.fillMaxWidth(), - ) - Text( - text = "${(progress * 100).toInt()}%", - style = MaterialTheme.typography.bodySmall, - color = MaterialTheme.colorScheme.onSurfaceVariant, - ) - } - } - } -} @Composable private fun VoiceOutputSelectionCard( @@ -1632,19 +1421,10 @@ private fun VoiceScreenPreview() { onSherpaGainChanged = {}, onAutoSpeakChanged = {}, onMaxSpokenSentencesChanged = {}, - onDownloadVoice = {}, - onCancelVoiceDownload = {}, - onDeleteVoice = {}, onActiveSpeakerIdChanged = {}, onKokoroVoiceSelected = {}, - onDownloadKokoroVoice = {}, - onCancelKokoroVoiceDownload = {}, - onDeleteKokoroVoice = {}, onKokoroActiveSpeakerIdChanged = {}, - onDownloadSherpaStt = {}, - onCancelSherpaSttDownload = {}, - onDeleteSherpaStt = {}, - onViewSherpaSttLicence = {}, + onNavigateToModelManagement = {}, ) } } diff --git a/feature/settings/src/main/java/com/kernel/ai/feature/settings/VoiceViewModel.kt b/feature/settings/src/main/java/com/kernel/ai/feature/settings/VoiceViewModel.kt index dc3b7474f..9697390b5 100644 --- a/feature/settings/src/main/java/com/kernel/ai/feature/settings/VoiceViewModel.kt +++ b/feature/settings/src/main/java/com/kernel/ai/feature/settings/VoiceViewModel.kt @@ -32,6 +32,9 @@ import kotlinx.coroutines.flow.first import kotlinx.coroutines.launch import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import com.kernel.ai.core.model.availability.ActionReason +import com.kernel.ai.core.model.availability.ModelAvailabilityState +import com.kernel.ai.core.model.availability.UnavailableReason import javax.inject.Inject data class SherpaSttDownloadIssue( @@ -96,11 +99,47 @@ data class VoiceUiState( // ── Sherpa-ONNX STT model download states (per family) ────────────────── /** Per-family download state for each Sherpa STT engine. */ val sherpaSttStates: Map = emptyMap(), + // ── Model availability states (for ModelCardCompact) ──────────────────── + /** Per-voice availability state for Sherpa Piper voices. */ + val sherpaVoiceAvailability: Map = emptyMap(), + /** Per-voice availability state for Kokoro voices. */ + val kokoroVoiceAvailability: Map = emptyMap(), + /** Per-engine availability state for Sherpa STT models. */ + val sherpaSttAvailability: Map = emptyMap(), ) internal fun resolveAndroidNativeAvailabilityMessage( availability: AndroidNativeRecognitionAvailability, ): String? = availability.warningMessage +/** + * Maps a [VoicePackDownloadState] to the UI-layer [ModelAvailabilityState]. + * Used by [ModelCardCompact] in the voice screen. + */ +internal fun VoicePackDownloadState.toModelAvailability(): ModelAvailabilityState = when (this) { + is VoicePackDownloadState.Downloaded -> ModelAvailabilityState.Ready + is VoicePackDownloadState.Downloading -> ModelAvailabilityState.Preparing( + progress = progress, + isAutoQueued = false, + ) + is VoicePackDownloadState.Error -> ModelAvailabilityState.ActionRequired( + ActionReason.DownloadFailed(message) + ) + is VoicePackDownloadState.NotDownloaded -> ModelAvailabilityState.Unavailable( + UnavailableReason.NotBundled + ) +} + +/** + * Maps a [SherpaSttDownloadState] to the UI-layer [ModelAvailabilityState]. + */ +internal fun SherpaSttDownloadState.toModelAvailability(): ModelAvailabilityState = when { + isDownloaded -> ModelAvailabilityState.Ready + isDownloading -> ModelAvailabilityState.Preparing(progress = progress, isAutoQueued = false) + issue?.licenceRequired == true -> ModelAvailabilityState.ActionRequired(ActionReason.LicenseRequired) + issue != null -> ModelAvailabilityState.ActionRequired(ActionReason.DownloadFailed(issue.message)) + else -> ModelAvailabilityState.Unavailable(UnavailableReason.NotBundled) +} + @HiltViewModel class VoiceViewModel @Inject constructor( @@ -189,6 +228,9 @@ class VoiceViewModel @Inject constructor( }, isSelectedSherpaVoiceDownloaded = states[it.selectedSherpaVoice] is VoicePackDownloadState.Downloaded, + sherpaVoiceAvailability = SherpaPiperVoice.entries.associateWith { voice -> + (states[voice] ?: VoicePackDownloadState.NotDownloaded).toModelAvailability() + }, ) } } @@ -228,6 +270,9 @@ class VoiceViewModel @Inject constructor( kokoroVoices = rows, isSelectedKokoroVoiceDownloaded = states[state.selectedKokoroVoice] is VoicePackDownloadState.Downloaded, + kokoroVoiceAvailability = SherpaKokoroVoice.entries.associateWith { voice -> + (states[voice] ?: VoicePackDownloadState.NotDownloaded).toModelAvailability() + }, ) } } @@ -257,7 +302,14 @@ class VoiceViewModel @Inject constructor( val perFamilyStates = SherpaSttModelSpec.ALL.mapValues { (engine, spec) -> computeDownloadState(spec, states) } - _uiState.update { it.copy(sherpaSttStates = perFamilyStates) } + _uiState.update { + it.copy( + sherpaSttStates = perFamilyStates, + sherpaSttAvailability = perFamilyStates.mapValues { (_, state) -> + state.toModelAvailability() + }, + ) + } } } } diff --git a/settings.gradle.kts b/settings.gradle.kts index 358fe7239..3353f1774 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -22,6 +22,7 @@ include(":core:memory") include(":core:voice") include(":core:wasm") include(":core:ui") +include(":core:model-availability") include(":core:skills") include(":feature:chat") include(":feature:convert")