18 changes: 17 additions & 1 deletion CHANGELOG.md
@@ -9,7 +9,23 @@ Versioning follows [Semantic Versioning](https://semver.org/).

## [Unreleased]

(nothing yet)
### Added
- **Prompt cache tiering** (v0.4.0 engine parity, part 1 of 3).
Successive chat turns on the same model now reuse the KV cache
when the new prompt extends the previous one — the shared prefix
skips prefill. An in-memory hot tier (LRU, 8 entries in the MVP)
is backed by an on-disk cold tier at `~/.mac-mlx/kv-cache/`:
16-way sharded safetensors round-tripped through mlx-swift-lm's
`savePromptCache` / `loadPromptCache`. Coding-assistant workflows
(Claude Code, Cursor, Zed re-sending conversation history each
turn) see reduced time-to-first-token on repeat prefixes; the
call-site flow is sketched just after this list.
- Settings → "KV Cache" section with hot/cold budget steppers and
a "Clear All KV Caches" button. Steppers currently inform future
byte-accurate budgeting (v0.4.0.1) — today's enforcement is the
8-entry hot LRU cap plus manual Clear.
- Debug-level Logs tab entries `Prompt cache HIT — restored N
tokens` / `Prompt cache MISS — cold prefill of N tokens` under
the `inference` category, so you can see cache effectiveness.
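
For orientation, here is a condensed sketch of the lookup-then-save flow that the `MLXSwiftEngine` changes below implement. It is not part of the diff: `PromptCacheKey`, `PromptCacheStore`, and `PromptCacheSnapshot` come from this PR, while `freshCache` and `generateAndCollect` are hypothetical stand-ins for `model.newCache(parameters:)` and the `generateTokens` loop.

```swift
import MLXLMCommon

// Condensed from runGeneration (MLXSwiftEngine.swift, below); a sketch,
// not a drop-in implementation. `freshCache` and `generateAndCollect`
// are hypothetical stand-ins for the real allocation and generation code.
func cachedGenerate(
    modelID: String,
    inputTokens: [Int],
    store: PromptCacheStore,
    freshCache: () -> [any KVCache],
    generateAndCollect: ([any KVCache]) async throws -> [Int]
) async throws {
    let key = PromptCacheKey(modelID: modelID, tokens: inputTokens)

    // HIT: reuse the restored [KVCache] so the shared prefix skips prefill.
    // MISS: allocate a fresh cache and pay for a cold prefill.
    let prior = await store.get(key)
    let cache = prior?.caches ?? freshCache()

    // KVCache is class-bound, so generation mutates `cache` in place.
    let generatedTokenIDs = try await generateAndCollect(cache)

    // Save under the extended key (prompt + reply) for the next turn.
    await store.put(
        key: PromptCacheKey(modelID: modelID, tokens: inputTokens + generatedTokenIDs),
        snapshot: PromptCacheSnapshot(cache)
    )
}
```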

---

136 changes: 122 additions & 14 deletions MacMLXCore/Sources/MacMLXCore/Engine/MLXSwiftEngine.swift
@@ -1,8 +1,21 @@
import Foundation
import MLX
import MLXLLM
import MLXLMCommon
@preconcurrency import Tokenizers

// MARK: - Sendable-box helpers

/// Lightweight unchecked-Sendable wrapper used to pass non-Sendable
/// mlx-swift-lm values (`LMInput`, `AsyncStream<TokenGeneration>`) across
/// isolation boundaries when we know the handoff is safe — we `consume`
/// them into the actor via `ModelContainer.perform(nonSendable:_:)` and
/// the actor owns them exclusively afterwards.
private struct NonSendableBox<T>: @unchecked Sendable {
let value: T
init(_ value: T) { self.value = value }
}

// MARK: - Tokenizer loader

/// Concrete TokenizerLoader that uses the HuggingFace swift-transformers library.
@@ -88,9 +101,18 @@ public actor MLXSwiftEngine: InferenceEngine {

private var modelContainer: ModelContainer?

/// Two-tier prompt cache (hot dict + cold safetensors sidecar). Used
/// by `runGeneration` to reuse KV state across successive turns on
/// the same model. See `PromptCacheStore` for the tiering policy.
private let promptCacheStore: PromptCacheStore

// MARK: Initialiser

public init() {}
public init() {
self.promptCacheStore = PromptCacheStore(
root: DataRoot.macMLX("kv-cache")
)
}

// MARK: InferenceEngine

@@ -205,9 +227,30 @@ public actor MLXSwiftEngine: InferenceEngine {
true
}

// MARK: Prompt cache management

/// Drop both tiers of the prompt cache. Wired up to the Settings
/// → "Clear All KV Caches" button via `EngineCoordinator`.
public func clearPromptCache() async {
await promptCacheStore.clearAll()
}

// MARK: Private generation helper

/// Actor-isolated generation driver called from within `generate(_:)`.
///
/// Flow:
/// 1. Prepare the `LMInput` (tokenisation + chat template application).
/// 2. Hash the full input-token sequence into a `PromptCacheKey`.
/// 3. Look up a prior cache snapshot in `promptCacheStore`. On hit,
/// reuse its `[KVCache]` so the shared prefix skips prefill. On
/// miss, allocate a fresh cache via `model.newCache(...)`.
/// 4. Drive the low-level `generateTokens(input:cache:...)` call so
/// we see raw token IDs and can build the extended key
/// `inputTokens + generatedTokenIDs` after the stream ends.
/// 5. The `KVCache` protocol is class-bound — the same reference we
/// passed in is mutated in-place during generation, so at the
/// end we can save that same reference under the extended key.
private func runGeneration(
_ request: GenerateRequest,
into continuation: AsyncThrowingStream<GenerateChunk, Error>.Continuation
@@ -216,6 +259,10 @@ guard let container = modelContainer else {
continuation.finish(throwing: EngineError.modelNotLoaded)
return
}
guard let loadedModelSnapshot = loadedModel else {
continuation.finish(throwing: EngineError.modelNotLoaded)
return
}

let params = request.parameters

@@ -261,28 +308,89 @@
throw EngineError.modelLoadFailed(reason: error.localizedDescription)
}

// Generate and stream chunks.
let stream = try await container.generate(input: lmInput, parameters: generateParams)
// Flat Int token array for key construction. `LMInput.text.tokens`
// is an `MLXArray`; `asArray(Int.self)` materialises to Swift.
let inputTokens = lmInput.text.tokens.asArray(Int.self)
let modelID = loadedModelSnapshot.id
let priorKey = PromptCacheKey(modelID: modelID, tokens: inputTokens)

// Try the store. On hit we reuse the restored cache; on miss we
// allocate a fresh one below, just before calling `generateTokens`.
let priorSnapshot = await promptCacheStore.get(priorKey)
let priorCache: [any KVCache]?
if let snapshot = priorSnapshot {
priorCache = snapshot.caches
await LogManager.shared.debug(
"Prompt cache HIT — restored \(priorKey.tokenCount) tokens (model=\(modelID))",
category: .inference
)
} else {
priorCache = nil
await LogManager.shared.debug(
"Prompt cache MISS — cold prefill of \(priorKey.tokenCount) tokens (model=\(modelID))",
category: .inference
)
}

// Build the working cache. When we have a prior snapshot we pass
// that reference straight through; otherwise we ask the model to
// allocate a fresh `[KVCache]`. We hold onto the same array so we
// can save it after generation (KVCache is class-bound, so the
// iterator populates our instances in place).
//
// Neither `KVCache` nor `LMInput` is `Sendable`. The input is routed
// through the `perform(nonSendable:_:)` overload on `ModelContainer`,
// which explicitly accepts a non-Sendable value by `consuming` it into
// the actor; the prior cache rides along wrapped in `PromptCacheSnapshot`.
let tokenizer = await container.tokenizer
let priorCacheBox: PromptCacheSnapshot? = priorCache.map { PromptCacheSnapshot($0) }
let inputBox = NonSendableBox(lmInput)

let setup: (cache: PromptCacheSnapshot, stream: AsyncStream<TokenGeneration>) =
try await container.perform(nonSendable: inputBox) { context, inputBox in
let cache: [any KVCache] = priorCacheBox?.caches
?? context.model.newCache(parameters: generateParams)
let stream = try MLXLMCommon.generateTokens(
input: inputBox.value,
cache: cache,
parameters: generateParams,
context: context
)
return (PromptCacheSnapshot(cache), stream)
}
let workingCache = setup.cache.caches
let stream = setup.stream

var detokenizer = NaiveStreamingDetokenizer(tokenizer: tokenizer)
var generatedTokenIDs: [Int] = []
var completionInfo: GenerateCompletionInfo?

for await generation in stream {
switch generation {
case .chunk(let text):
let chunk = GenerateChunk(text: text)
if case .terminated = continuation.yield(chunk) {
return
for await event in stream {
switch event {
case .token(let token):
generatedTokenIDs.append(token)
detokenizer.append(token: token)
if let piece = detokenizer.next() {
let chunk = GenerateChunk(text: piece)
if case .terminated = continuation.yield(chunk) {
return
}
}
case .info(let info):
completionInfo = info
case .toolCall:
// Tool calls not supported yet — out of scope through v0.3.
// Re-visit when there's a concrete tool-use feature to
// wire into (e.g. OpenAI-compatible function-calling).
break
}
}

// Save the post-generation cache under the extended key. The
// same `workingCache` reference has been mutated in-place by the
// iterator, so it now reflects prompt + generated tokens.
let finalTokens = inputTokens + generatedTokenIDs
let newKey = PromptCacheKey(modelID: modelID, tokens: finalTokens)
await promptCacheStore.put(
key: newKey,
snapshot: PromptCacheSnapshot(workingCache)
)

// Emit the final chunk with usage + finish reason.
if let info = completionInfo {
let finishReason: FinishReason
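
`PromptCacheStore` itself is not part of this diff section, only its call sites are. As a hypothetical sketch, inferred from those call sites (`get`, `put(key:snapshot:)`, `clearAll`) and the CHANGELOG notes (8-entry hot LRU, cold safetensors sidecar under the cache root), the store's shape might look roughly like this; the cold-tier round-trip is left as comments because the mlx-swift-lm save/load API is not shown here.

```swift
import Foundation

// Hypothetical sketch only; the real PromptCacheStore lives outside this
// diff. Its shape is inferred from the engine's call sites above.
actor PromptCacheStoreSketch {
    private let root: URL                                    // e.g. ~/.mac-mlx/kv-cache/
    private var hot: [PromptCacheKey: PromptCacheSnapshot] = [:]
    private var hotOrder: [PromptCacheKey] = []              // LRU order, most recent last
    private let hotCapacity = 8                              // MVP entry cap

    init(root: URL) { self.root = root }

    func get(_ key: PromptCacheKey) -> PromptCacheSnapshot? {
        if let hit = hot[key] {
            touch(key)
            return hit
        }
        // Cold tier: key.shardedFileURL(under: root) names the safetensors
        // sidecar. Loading it back into [KVCache] is elided here because
        // the mlx-swift-lm loadPromptCache API is not part of this section.
        return nil
    }

    func put(key: PromptCacheKey, snapshot: PromptCacheSnapshot) {
        hot[key] = snapshot
        touch(key)
        while hotOrder.count > hotCapacity {
            let evicted = hotOrder.removeFirst()
            hot[evicted] = nil
            // A real store would spill the evicted snapshot to the cold
            // tier (savePromptCache) here rather than dropping it.
        }
    }

    func clearAll() {
        hot.removeAll()
        hotOrder.removeAll()
        try? FileManager.default.removeItem(at: root)        // drop the cold tier too
    }

    private func touch(_ key: PromptCacheKey) {
        hotOrder.removeAll { $0 == key }
        hotOrder.append(key)
    }
}
```
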
65 changes: 63 additions & 2 deletions MacMLXCore/Sources/MacMLXCore/Managers/SettingsManager.swift
@@ -46,6 +46,22 @@ public struct Settings: Codable, Equatable, Sendable {
/// this at a mirror like "https://hf-mirror.com" (#21).
public var hfEndpoint: String

/// Hot prompt-cache capacity in megabytes — in-memory only.
///
/// MVP note: `PromptCacheStore`'s `hotCapacity` is an *entry* count,
/// not a byte budget. We persist the MB value for forward-compat so
/// a byte-accurate budget can land in v0.4.0.1 without a settings
/// migration. Today the engine ignores this value and uses the
/// default 8-entry cap.
public var kvCacheHotMB: Int

/// Cold prompt-cache disk cap in gigabytes.
///
/// MVP note: automatic cold-tier pruning is not yet implemented —
/// rely on Settings → "Clear All KV Caches" to reclaim space. Real
/// enforcement lands in v0.4.0.1.
public var kvCacheColdGB: Int

// MARK: Factory

/// Sensible out-of-the-box defaults — used when no settings file exists.
@@ -63,7 +79,9 @@
swiftLMPath: nil,
sparkleUpdateChannel: "release",
logRetentionDays: 7,
hfEndpoint: "https://huggingface.co"
hfEndpoint: "https://huggingface.co",
kvCacheHotMB: 512,
kvCacheColdGB: 20
)

// MARK: Init
@@ -79,7 +97,9 @@
swiftLMPath: String?,
sparkleUpdateChannel: String,
logRetentionDays: Int,
hfEndpoint: String = "https://huggingface.co"
hfEndpoint: String = "https://huggingface.co",
kvCacheHotMB: Int = 512,
kvCacheColdGB: Int = 20
) {
self.modelDirectory = modelDirectory
self.preferredEngine = preferredEngine
@@ -92,6 +112,47 @@
self.sparkleUpdateChannel = sparkleUpdateChannel
self.logRetentionDays = logRetentionDays
self.hfEndpoint = hfEndpoint
self.kvCacheHotMB = kvCacheHotMB
self.kvCacheColdGB = kvCacheColdGB
}

// MARK: - Codable (backward-compat decode)

/// Pre-v0.4 settings files don't have `kvCacheHotMB` /
/// `kvCacheColdGB` — decode them as optionals and fall back to the
/// defaults so existing installs keep working across upgrades.
private enum CodingKeys: String, CodingKey {
case modelDirectory
case preferredEngine
case serverPort
case autoStartServer
case lastLoadedModel
case onboardingComplete
case pythonPath
case swiftLMPath
case sparkleUpdateChannel
case logRetentionDays
case hfEndpoint
case kvCacheHotMB
case kvCacheColdGB
}

public init(from decoder: Decoder) throws {
let c = try decoder.container(keyedBy: CodingKeys.self)
self.modelDirectory = try c.decode(URL.self, forKey: .modelDirectory)
self.preferredEngine = try c.decode(EngineID.self, forKey: .preferredEngine)
self.serverPort = try c.decode(Int.self, forKey: .serverPort)
self.autoStartServer = try c.decode(Bool.self, forKey: .autoStartServer)
self.lastLoadedModel = try c.decodeIfPresent(String.self, forKey: .lastLoadedModel)
self.onboardingComplete = try c.decode(Bool.self, forKey: .onboardingComplete)
self.pythonPath = try c.decodeIfPresent(String.self, forKey: .pythonPath)
self.swiftLMPath = try c.decodeIfPresent(String.self, forKey: .swiftLMPath)
self.sparkleUpdateChannel = try c.decode(String.self, forKey: .sparkleUpdateChannel)
self.logRetentionDays = try c.decode(Int.self, forKey: .logRetentionDays)
self.hfEndpoint = try c.decodeIfPresent(String.self, forKey: .hfEndpoint)
?? "https://huggingface.co"
self.kvCacheHotMB = try c.decodeIfPresent(Int.self, forKey: .kvCacheHotMB) ?? 512
self.kvCacheColdGB = try c.decodeIfPresent(Int.self, forKey: .kvCacheColdGB) ?? 20
}
}
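
The decode-with-fallback above is what keeps pre-v0.4 settings files working. A minimal, self-contained illustration of the same pattern, using a hypothetical two-field stand-in rather than the real `Settings` type:

```swift
import Foundation

// Hypothetical stand-in for Settings; only demonstrates the
// decodeIfPresent-plus-default pattern used above.
struct MiniSettings: Decodable {
    var logRetentionDays: Int
    var kvCacheHotMB: Int

    enum CodingKeys: String, CodingKey { case logRetentionDays, kvCacheHotMB }

    init(from decoder: Decoder) throws {
        let c = try decoder.container(keyedBy: CodingKeys.self)
        logRetentionDays = try c.decode(Int.self, forKey: .logRetentionDays)
        // Key absent in an older file: fall back to the default.
        kvCacheHotMB = try c.decodeIfPresent(Int.self, forKey: .kvCacheHotMB) ?? 512
    }
}

let legacy = Data(#"{ "logRetentionDays": 7 }"#.utf8)    // old shape, no KV keys
let decoded = try JSONDecoder().decode(MiniSettings.self, from: legacy)
assert(decoded.logRetentionDays == 7)
assert(decoded.kvCacheHotMB == 512)                      // default applied
```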

48 changes: 48 additions & 0 deletions MacMLXCore/Sources/MacMLXCore/PromptCache/PromptCacheKey.swift
@@ -0,0 +1,48 @@
import CryptoKit
import Foundation

/// Deterministic hash key identifying a cached KV-cache snapshot.
///
/// MVP hashes the entire token prefix. v0.4.1+ will switch to a
/// vLLM-style chained block hash (256 tokens per block + parent
/// hash) to enable longest-common-prefix matching across siblings;
/// today two requests have to share the EXACT same prefix to
/// benefit from the cache.
public struct PromptCacheKey: Hashable, Sendable {
public let modelID: String
public let tokenCount: Int
public let hashString: String

public init(modelID: String, tokens: [Int]) {
self.modelID = modelID
self.tokenCount = tokens.count
self.hashString = Self.hash(modelID: modelID, tokens: tokens)
}

/// SHA-256 over `(modelID, tokens)`. Tokens encoded as
/// little-endian Int32 for cross-platform stability.
private static func hash(modelID: String, tokens: [Int]) -> String {
var hasher = SHA256()
if let modelBytes = modelID.data(using: .utf8) {
hasher.update(data: modelBytes)
}
hasher.update(data: Data([0x00])) // separator
var buf = Data(capacity: tokens.count * 4)
for tok in tokens {
var v = Int32(tok).littleEndian
withUnsafeBytes(of: &v) { buf.append(contentsOf: $0) }
}
hasher.update(data: buf)
return hasher.finalize().map { String(format: "%02x", $0) }.joined()
}

/// `<root>/<shardChar>/<fullHash>.safetensors`. 16-way fanout
/// keeps any single directory from getting huge when the cold
/// store grows. `shardChar` is the first hex char of the hash.
public func shardedFileURL(under root: URL) -> URL {
let shard = String(hashString.prefix(1))
return root
.appending(path: shard, directoryHint: .isDirectory)
.appending(path: "\(hashString).safetensors", directoryHint: .notDirectory)
}
}
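
A short usage sketch of the key and the sharded cold-tier layout. The model ID and token IDs are invented; the behaviour shown is only what the code above defines: the key is a SHA-256 over `(modelID, tokens)`, the shard is the first hex character of that hash, and any change to the token sequence yields a different key.

```swift
import Foundation

// Illustrative only; the repo ID and token IDs are invented.
let tokens: [Int] = [151644, 872, 198, 9707]
let key = PromptCacheKey(modelID: "mlx-community/ExampleModel-4bit", tokens: tokens)

let root = FileManager.default.homeDirectoryForCurrentUser
    .appending(path: ".mac-mlx/kv-cache", directoryHint: .isDirectory)
let sidecar = key.shardedFileURL(under: root)
// key.hashString is 64 lowercase hex characters; the sidecar path is
//   ~/.mac-mlx/kv-cache/<first hex char>/<full hash>.safetensors

// Extending the prefix (prompt + generated reply) produces a different
// key, which is how runGeneration stores the post-generation snapshot.
let extended = PromptCacheKey(
    modelID: "mlx-community/ExampleModel-4bit",
    tokens: tokens + [40, 2776]
)
assert(extended.hashString != key.hashString)
```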