-
-
Notifications
You must be signed in to change notification settings - Fork 27
ADFA-4388 | embedding model crash #1429
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: stage
Are you sure you want to change the base?
Changes from all commits
77f0700
e5fb835
7ef3e80
8200a6b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -63,6 +63,29 @@ class LlmInferenceEngine( | |
| private const val CONTEXT_SIZE_MID_MEM = 2048 | ||
| private const val CONTEXT_SIZE_HIGH_MEM = 3072 | ||
| private const val CONTEXT_SIZE_MAX = 4096 | ||
|
|
||
| private const val EXT_ONNX = ".onnx" | ||
| private const val EXT_PT = ".pt" | ||
| private const val EXT_PTH = ".pth" | ||
| private const val EXT_BIN = ".bin" | ||
| private const val EXT_SAFETENSORS = ".safetensors" | ||
| private const val EXT_PB = ".pb" | ||
| private const val EXT_TFLITE = ".tflite" | ||
| private const val EXT_GGML = ".ggml" | ||
| private const val EXT_GGUF = ".gguf" | ||
|
|
||
| private const val KEYWORD_TENSORFLOW = "tensorflow" | ||
| private const val KEYWORD_ALL_MINI = "all-mini" | ||
| private const val KEYWORD_ALL_MPNET = "all-mpnet" | ||
| private const val KEYWORD_E5 = "e5-" | ||
| private const val KEYWORD_EMBED = "embed" | ||
| private const val KEYWORD_LLAMA = "llama" | ||
| private const val KEYWORD_H2O = "h2o" | ||
| private const val KEYWORD_DANUBE = "danube" | ||
| private const val KEYWORD_QWEN = "qwen" | ||
| private const val KEYWORD_GEMMA3 = "gemma3" | ||
| private const val KEYWORD_GEMMA_3 = "gemma-3" | ||
| private const val KEYWORD_GEMMA = "gemma" | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -291,9 +314,12 @@ class LlmInferenceEngine( | |
| modelUriString: String, | ||
| expectedSha256: String? | ||
| ): Boolean { | ||
| val modelUri = modelUriString.toUri() | ||
| val displayName = resolveModelDisplayName(context, modelUri) | ||
|
|
||
| return try { | ||
| val modelUri = modelUriString.toUri() | ||
| val displayName = resolveModelDisplayName(context, modelUri) | ||
| validateModelFormat(displayName) | ||
|
|
||
| val destinationFile = File(context.cacheDir, "local_model.gguf") | ||
|
|
||
| if (!copyModelToCache(context, modelUri, destinationFile)) { | ||
|
|
@@ -313,6 +339,23 @@ class LlmInferenceEngine( | |
| currentModelFamily = detectModelFamily(displayName) | ||
| log.info("Successfully loaded local model: {}", loadedModelName) | ||
| true | ||
| } catch (e: IllegalStateException) { | ||
| if (e.message?.contains("embedding model") == true) { | ||
| log.error("Cannot use embedding model for chat: {}", displayName, e) | ||
| throw IllegalArgumentException( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Uncaught exception reintroduces the crash this PR fixes (on other load paths). After this change Only
Scenario: a saved local-model path that resolves to an embedding/unsupported model (e.g. saved by a pre-fix build, or a Fix: either wrap those three call sites in try/catch and surface the message, or have |
||
| "The selected model '$displayName' is an embedding model designed for semantic " + | ||
| "search and similarity tasks. It cannot be used for chat or text generation.\n\n" + | ||
| "Please select a chat/instruct model instead (e.g., models with 'chat', 'instruct', " + | ||
| "'conversational' in their name).", e | ||
| ) | ||
| } else { | ||
| log.error("Failed to load model", e) | ||
| throw e | ||
| } | ||
| } catch (e: IllegalArgumentException) { | ||
| log.error("Model validation failed: {}", displayName, e) | ||
| resetLoadedModelState() | ||
| throw e | ||
| } catch (e: Exception) { | ||
| log.error("Failed to initialize or load model from file", e) | ||
| resetLoadedModelState() | ||
|
|
@@ -458,14 +501,92 @@ class LlmInferenceEngine( | |
| } | ||
| } | ||
|
|
||
| /** | ||
| * Validates that the model file format is supported. | ||
| * This app uses llama.cpp which only supports GGUF format. | ||
| * | ||
| * @throws IllegalArgumentException if the model format is not supported | ||
| */ | ||
| private fun validateModelFormat(filename: String) { | ||
| val lowerName = filename.lowercase() | ||
|
|
||
| when { | ||
| lowerName.endsWith(EXT_ONNX) -> { | ||
| throw IllegalArgumentException( | ||
| "ONNX models ($EXT_ONNX) are not supported.\n\n" + | ||
| "This app uses llama.cpp which only supports GGUF format ($EXT_GGUF).\n\n" + | ||
| "To use this model:\n" + | ||
| "1. Convert it to GGUF format using llama.cpp conversion tools\n" + | ||
| "2. Or download a pre-converted GGUF version from Hugging Face" | ||
| ) | ||
| } | ||
| lowerName.endsWith(EXT_PT) || lowerName.endsWith(EXT_PTH) || lowerName.endsWith(EXT_BIN) -> { | ||
| throw IllegalArgumentException( | ||
| "PyTorch models ($EXT_PT, $EXT_PTH, $EXT_BIN) are not supported.\n\n" + | ||
| "This app uses llama.cpp which only supports GGUF format ($EXT_GGUF).\n\n" + | ||
| "To use this model:\n" + | ||
| "1. Convert it to GGUF format using convert_hf_to_gguf.py\n" + | ||
| "2. Or download a pre-converted GGUF version from Hugging Face" | ||
| ) | ||
| } | ||
| lowerName.endsWith(EXT_SAFETENSORS) -> { | ||
| throw IllegalArgumentException( | ||
| "SafeTensors models ($EXT_SAFETENSORS) are not directly supported.\n\n" + | ||
| "This app uses llama.cpp which only supports GGUF format ($EXT_GGUF).\n\n" + | ||
| "To use this model:\n" + | ||
| "1. Convert it to GGUF format using convert_hf_to_gguf.py\n" + | ||
| "2. Or download a pre-converted GGUF version from Hugging Face" | ||
| ) | ||
| } | ||
| lowerName.endsWith(EXT_PB) || lowerName.contains(KEYWORD_TENSORFLOW) -> { | ||
| throw IllegalArgumentException( | ||
| "TensorFlow models ($EXT_PB) are not supported.\n\n" + | ||
| "This app uses llama.cpp which only supports GGUF format ($EXT_GGUF).\n\n" + | ||
| "To use this model:\n" + | ||
| "1. Convert it to GGUF format using appropriate conversion tools\n" + | ||
| "2. Or download a pre-converted GGUF version from Hugging Face" | ||
| ) | ||
| } | ||
| lowerName.endsWith(EXT_TFLITE) -> { | ||
| throw IllegalArgumentException( | ||
| "TensorFlow Lite models ($EXT_TFLITE) are not supported.\n\n" + | ||
| "This app uses llama.cpp which only supports GGUF format ($EXT_GGUF).\n\n" + | ||
| "Please select a GGUF format model." | ||
| ) | ||
| } | ||
| lowerName.endsWith(EXT_GGML) -> { | ||
| throw IllegalArgumentException( | ||
| "GGML models ($EXT_GGML) are deprecated.\n\n" + | ||
| "This app uses the newer GGUF format ($EXT_GGUF).\n\n" + | ||
| "To use this model:\n" + | ||
| "1. Convert it to GGUF using convert_llama_ggml_to_gguf.py\n" + | ||
| "2. Or download a GGUF version from Hugging Face" | ||
| ) | ||
| } | ||
| !lowerName.endsWith(EXT_GGUF) -> { | ||
| log.warn("Model file '{}' doesn't have $EXT_GGUF extension. May fail to load.", filename) | ||
| } | ||
| } | ||
|
|
||
| if (lowerName.contains(KEYWORD_ALL_MINI) || | ||
| lowerName.contains(KEYWORD_ALL_MPNET) || | ||
| lowerName.contains(KEYWORD_E5) || | ||
| (lowerName.contains(KEYWORD_EMBED) && !lowerName.contains(KEYWORD_LLAMA))) { | ||
| log.warn( | ||
| "Model '{}' appears to be an embedding model based on filename. " + | ||
| "This may not work for chat. Will validate during load.", filename | ||
| ) | ||
| } | ||
| } | ||
|
|
||
| private fun detectModelFamily(path: String): ModelFamily { | ||
| val lowerPath = path.lowercase() | ||
| return when { | ||
| lowerPath.contains("h2o") || lowerPath.contains("danube") -> ModelFamily.H2O | ||
| lowerPath.contains("qwen") -> ModelFamily.QWEN | ||
| lowerPath.contains("gemma-3") || lowerPath.contains("gemma3") -> ModelFamily.GEMMA3 | ||
| lowerPath.contains("gemma") -> ModelFamily.GEMMA2 | ||
| lowerPath.contains("llama") -> ModelFamily.LLAMA3 | ||
| lowerPath.contains(KEYWORD_H2O) || lowerPath.contains(KEYWORD_DANUBE) -> ModelFamily.H2O | ||
| lowerPath.contains(KEYWORD_QWEN) -> ModelFamily.QWEN | ||
| lowerPath.contains(KEYWORD_GEMMA_3) || lowerPath.contains(KEYWORD_GEMMA3) -> ModelFamily.GEMMA3 | ||
| lowerPath.contains(KEYWORD_GEMMA) -> ModelFamily.GEMMA2 | ||
| lowerPath.contains(KEYWORD_LLAMA) -> ModelFamily.LLAMA3 | ||
| else -> ModelFamily.UNKNOWN | ||
| } | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make embedding-error detection case-insensitive.
The current message check can miss variants like
Embedding model, which bypasses the intendedIllegalArgumentExceptionmapping and user guidance.Suggested fix
📝 Committable suggestion
🤖 Prompt for AI Agents