Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@ Versioning follows [Semantic Versioning](https://semver.org/).
- Debug-level Logs tab entries `Prompt cache HIT — restored N
tokens` / `Prompt cache MISS — cold prefill of N tokens` under
the `engine` category, so you can see cache effectiveness.
- **Multi-model pool** (v0.4.0 engine parity, part 2 of 3). Load
multiple models at once — previously the engine had to unload
the old model before loading a new one, which meant every API
cold-swap paid the full weight-read cost. Pool is bounded by a
user-configurable resident memory cap (Settings → Model Pool;
default 50% of total RAM). Least-recently-used non-pinned
models auto-evict when the cap is exceeded. Pin a model from
its row in the Models tab (pin icon) to keep it resident
regardless of LRU order. Pinned state is in-memory for this
release; persistence across restarts is a follow-up.

---

Expand Down
18 changes: 16 additions & 2 deletions MacMLXCore/Sources/MacMLXCore/Managers/SettingsManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ public struct Settings: Codable, Equatable, Sendable {
/// enforcement lands in v0.4.0.1.
public var kvCacheColdGB: Int

/// ModelPool byte budget, expressed in gigabytes (Apple's 10^9 GB
/// convention). When resident models' summed estimated footprint
/// exceeds this, the pool LRU-evicts non-pinned entries. Default
/// is 50% of the machine's physical RAM, clamped to a 4 GB floor
/// for small-memory Macs.
public var maxResidentMemoryGB: Int

// MARK: Factory

/// Sensible out-of-the-box defaults — used when no settings file exists.
Expand All @@ -81,7 +88,8 @@ public struct Settings: Codable, Equatable, Sendable {
logRetentionDays: 7,
hfEndpoint: "https://huggingface.co",
kvCacheHotMB: 512,
kvCacheColdGB: 20
kvCacheColdGB: 20,
maxResidentMemoryGB: max(4, Int(MemoryProbe.totalMemoryGB()) / 2)
)

// MARK: Init
Expand All @@ -99,7 +107,8 @@ public struct Settings: Codable, Equatable, Sendable {
logRetentionDays: Int,
hfEndpoint: String = "https://huggingface.co",
kvCacheHotMB: Int = 512,
kvCacheColdGB: Int = 20
kvCacheColdGB: Int = 20,
maxResidentMemoryGB: Int = max(4, Int(MemoryProbe.totalMemoryGB()) / 2)
) {
self.modelDirectory = modelDirectory
self.preferredEngine = preferredEngine
Expand All @@ -114,6 +123,7 @@ public struct Settings: Codable, Equatable, Sendable {
self.hfEndpoint = hfEndpoint
self.kvCacheHotMB = kvCacheHotMB
self.kvCacheColdGB = kvCacheColdGB
self.maxResidentMemoryGB = maxResidentMemoryGB
}

// MARK: - Codable (backward-compat decode)
Expand All @@ -135,6 +145,7 @@ public struct Settings: Codable, Equatable, Sendable {
case hfEndpoint
case kvCacheHotMB
case kvCacheColdGB
case maxResidentMemoryGB
}

public init(from decoder: Decoder) throws {
Expand All @@ -153,6 +164,9 @@ public struct Settings: Codable, Equatable, Sendable {
?? "https://huggingface.co"
self.kvCacheHotMB = try c.decodeIfPresent(Int.self, forKey: .kvCacheHotMB) ?? 512
self.kvCacheColdGB = try c.decodeIfPresent(Int.self, forKey: .kvCacheColdGB) ?? 20
self.maxResidentMemoryGB =
(try c.decodeIfPresent(Int.self, forKey: .maxResidentMemoryGB))
?? max(4, Int(MemoryProbe.totalMemoryGB()) / 2)
}
}

Expand Down
151 changes: 151 additions & 0 deletions MacMLXCore/Sources/MacMLXCore/ModelPool/ModelPool.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import Foundation

/// Actor managing multiple resident `InferenceEngine` instances with
/// LRU + explicit pinning + byte-budget auto-evict. Use cases:
///
/// - Swap between chat models without re-reading weights from disk
/// - External API cold-swap without unloading the GUI's current model
/// - Keep a small always-ready model pinned alongside a big one that
///   auto-evicts on memory pressure
///
/// Load path is serialised under `loadTasks` to avoid two concurrent
/// requests double-loading the same weights — the second caller awaits
/// the first's completion. In-flight load costs are also counted
/// toward the byte budget (see `pendingBytes`) so a second load that
/// arrives while the first is suspended cannot overshoot `maxBytes`.
public actor ModelPool {

    public typealias EngineFactory = @Sendable (LocalModel) -> any InferenceEngine

    // MARK: - State

    /// Currently resident engines, keyed by model ID.
    private var engines: [String: any InferenceEngine] = [:]
    /// Bookkeeping keyed by model ID.
    private var entries: [String: PooledEngineEntry] = [:]
    /// In-flight loads so concurrent callers deduplicate.
    private var loadTasks: [String: Task<any InferenceEngine, Error>] = [:]
    /// Estimated cost of each in-flight load, keyed by model ID.
    /// Actors are reentrant: while `load(_:)` is suspended awaiting
    /// the engine's weight read, another `load(_:)` for a different
    /// model can run. Without reserving the in-flight cost here, that
    /// second caller would size its eviction against `entries` alone
    /// and the pool could exceed `maxBytes`.
    private var pendingBytes: [String: Int64] = [:]

    private let engineFactory: EngineFactory

    // MARK: - Budget

    /// Maximum total estimated bytes that may be resident. Exceeding
    /// this triggers LRU eviction (pinned entries are spared).
    public var maxBytes: Int64

    /// - Parameters:
    ///   - maxBytes: Initial byte budget for resident models.
    ///   - engineFactory: Creates a fresh engine for a model; invoked
    ///     once per non-deduplicated load.
    public init(
        maxBytes: Int64,
        engineFactory: @escaping EngineFactory
    ) {
        self.maxBytes = maxBytes
        self.engineFactory = engineFactory
    }

    /// Update the byte budget. Shrinking the budget evicts LRU
    /// non-pinned entries immediately rather than waiting for the
    /// next `load(_:)`.
    public func setMaxBytes(_ bytes: Int64) {
        self.maxBytes = bytes
        evictIfNeeded()
    }

    // MARK: - Public

    /// IDs of all currently resident models, sorted for stable output.
    public func residentModelIDs() -> [String] {
        Array(engines.keys).sorted()
    }

    /// Resident engine for `modelID`, or `nil` when not loaded.
    /// Touches the LRU timestamp on a hit.
    public func engine(for modelID: String) -> (any InferenceEngine)? {
        guard let e = engines[modelID] else { return nil }
        touch(modelID)
        return e
    }

    /// Pin or unpin a resident model. Pinned entries are never
    /// LRU-evicted. No-op when the model is not resident.
    public func setPinned(_ modelID: String, _ pinned: Bool) {
        entries[modelID]?.isPinned = pinned
    }

    /// Whether `modelID` is resident AND pinned.
    public func isPinned(_ modelID: String) -> Bool {
        entries[modelID]?.isPinned ?? false
    }

    /// Remove `modelID` from the pool and unload its engine. A load
    /// already in flight for the same ID is NOT cancelled; it will
    /// complete and re-insert the model.
    public func unload(_ modelID: String) async {
        if let e = engines.removeValue(forKey: modelID) {
            try? await e.unload()
        }
        entries.removeValue(forKey: modelID)
    }

    /// Return an engine with `model.id` loaded. Reuses an existing
    /// entry when possible. Evicts LRU entries as needed to stay
    /// within `maxBytes`. Concurrent loads of the same ID share.
    ///
    /// - Throws: Whatever the engine's `load(_:)` throws.
    @discardableResult
    public func load(_ model: LocalModel) async throws -> any InferenceEngine {
        // Already loaded? Touch and return.
        if let e = engines[model.id] {
            touch(model.id)
            return e
        }
        // In-flight load by another caller? Join it.
        if let pending = loadTasks[model.id] {
            return try await pending.value
        }

        // Reserve the incoming cost, then evict to fit BEFORE the
        // (potentially long) weight read begins. Cost comes from the
        // model's recorded size, falling back to an on-disk estimate.
        let cost = model.sizeBytes > 0 ? model.sizeBytes : estimateModelSize(at: model.directory)
        pendingBytes[model.id] = cost
        evictIfNeeded()

        let factory = engineFactory
        let task = Task { () throws -> any InferenceEngine in
            let engine = factory(model)
            try await engine.load(model)
            return engine
        }
        loadTasks[model.id] = task
        do {
            let engine = try await task.value
            loadTasks.removeValue(forKey: model.id)
            pendingBytes.removeValue(forKey: model.id)
            engines[model.id] = engine
            entries[model.id] = PooledEngineEntry(
                modelID: model.id,
                estimatedBytes: cost
            )
            return engine
        } catch {
            // Release both the dedup slot and the budget reservation
            // so a retry starts from a clean slate.
            loadTasks.removeValue(forKey: model.id)
            pendingBytes.removeValue(forKey: model.id)
            throw error
        }
    }

    // MARK: - Eviction

    /// Refresh the LRU timestamp for a resident model.
    private func touch(_ modelID: String) {
        entries[modelID]?.lastAccess = Date()
    }

    /// Resident plus in-flight estimated bytes.
    private func currentTotalBytes() -> Int64 {
        entries.values.map(\.estimatedBytes).reduce(0, +)
            + pendingBytes.values.reduce(0, +)
    }

    /// Evict LRU non-pinned entries until the total (resident plus
    /// in-flight) fits within `maxBytes`. If every evictable entry is
    /// gone and the total still exceeds the budget (e.g. a single
    /// model larger than `maxBytes`), the load proceeds anyway.
    private func evictIfNeeded() {
        var current = currentTotalBytes()
        guard current > maxBytes else { return }

        // Candidates: non-pinned, oldest first.
        let candidates = entries.values
            .filter { !$0.isPinned }
            .sorted { $0.lastAccess < $1.lastAccess }

        for victim in candidates {
            guard current > maxBytes else { break }
            if let e = engines.removeValue(forKey: victim.modelID) {
                // Fire-and-forget: eviction must not block the caller
                // on the old engine's teardown.
                Task { try? await e.unload() }
            }
            entries.removeValue(forKey: victim.modelID)
            current -= victim.estimatedBytes
        }
    }
}
49 changes: 49 additions & 0 deletions MacMLXCore/Sources/MacMLXCore/ModelPool/PooledEngineEntry.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import Foundation

/// Value-type metadata tracked by `ModelPool` for each resident
/// model. The engine instance itself lives in the pool's dictionary;
/// only budget and LRU bookkeeping lives here.
public struct PooledEngineEntry: Sendable, Equatable {
    /// Identifier of the model (same value as `LocalModel.id`).
    public let modelID: String
    /// Estimated resident cost in bytes — the sum of the model
    /// directory's safetensors file sizes. A rough proxy (real MLX
    /// allocator usage can run 10–30% higher) but stable enough for
    /// budget arithmetic.
    public let estimatedBytes: Int64
    /// When this entry was last touched via `load(_:)` or
    /// `engine(for:)`; drives least-recently-used eviction ordering.
    public var lastAccess: Date
    /// While `true`, the LRU sweeper skips this entry entirely.
    public var isPinned: Bool

    public init(
        modelID: String,
        estimatedBytes: Int64,
        lastAccess: Date = Date(),
        isPinned: Bool = false
    ) {
        self.isPinned = isPinned
        self.lastAccess = lastAccess
        self.estimatedBytes = estimatedBytes
        self.modelID = modelID
    }
}

/// Sum of `.safetensors` file sizes directly inside `directory` — a
/// rough proxy for the memory the model needs once loaded. Returns 0
/// on any filesystem error.
///
/// NOTE(review): `contentsOfDirectory` lists only the top level, so
/// shards placed in subdirectories would not be counted — confirm all
/// supported model layouts keep weights flat.
public func estimateModelSize(at directory: URL) -> Int64 {
    let fileManager = FileManager.default
    guard let children = try? fileManager.contentsOfDirectory(
        at: directory,
        includingPropertiesForKeys: [.fileSizeKey]
    ) else {
        return 0
    }
    var total: Int64 = 0
    for url in children where url.pathExtension.lowercased() == "safetensors" {
        // Unreadable entries are skipped rather than failing the sum.
        if let size = (try? url.resourceValues(forKeys: [.fileSizeKey]))?.fileSize {
            total += Int64(size)
        }
    }
    return total
}
102 changes: 102 additions & 0 deletions MacMLXCore/Tests/MacMLXCoreTests/ModelPool/ModelPoolTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import XCTest
@testable import MacMLXCore

/// Minimal in-memory `InferenceEngine` used by the pool tests so no
/// Metal/MLX stack is required. Only the members the pool touches do
/// real work (load/unload/engineID); `generate` always finishes with
/// an error because the pool never streams tokens in these tests.
private actor StubEngine: InferenceEngine {
    nonisolated let engineID: EngineID = .mlxSwift
    var status: EngineStatus = .idle
    var loadedModel: LocalModel?
    var version: String = "stub"

    func load(_ model: LocalModel) async throws {
        status = .ready(model: model.id)
        loadedModel = model
    }

    func unload() async throws {
        status = .idle
        loadedModel = nil
    }

    nonisolated func generate(_ request: GenerateRequest) -> AsyncThrowingStream<GenerateChunk, Error> {
        AsyncThrowingStream { continuation in
            continuation.finish(throwing: EngineError.modelNotLoaded)
        }
    }

    func healthCheck() async -> Bool { true }
}

/// Pool behaviour tests — run entirely against `StubEngine`, so no
/// Metal/MLX stack or real model weights are required.
final class ModelPoolTests: XCTestCase {

    /// Convenience factory for a throwaway `LocalModel`. The directory
    /// points at the shared temp dir; the stub engine never reads it.
    private func makeModel(_ id: String, size: Int64 = 1_000_000_000) -> LocalModel {
        LocalModel(
            id: id,
            displayName: id,
            directory: FileManager.default.temporaryDirectory,
            sizeBytes: size,
            format: .mlx,
            quantization: nil,
            parameterCount: nil,
            architecture: nil
        )
    }

    func testLoadAddsToPool() async throws {
        let pool = ModelPool(maxBytes: 4_000_000_000, engineFactory: { _ in StubEngine() })
        _ = try await pool.load(makeModel("A", size: 1_000_000_000))
        let resident = await pool.residentModelIDs()
        XCTAssertEqual(resident, ["A"])
    }

    func testLoadReuseExistingInstance() async throws {
        let pool = ModelPool(maxBytes: 4_000_000_000, engineFactory: { _ in StubEngine() })
        let model = makeModel("A", size: 1_000_000_000)
        let first = try await pool.load(model)
        let second = try await pool.load(model)
        XCTAssertIdentical(first as AnyObject, second as AnyObject)
    }

    func testOverBudgetEvictsLRU() async throws {
        // 2.5 GB budget: A+B (2 GB) fit; adding C (1 GB) goes over,
        // so the least-recently-used entry (A) must be evicted.
        // NOTE(review): LRU order relies on successive `Date()` stamps
        // differing — on very fast machines ties are conceivable.
        let pool = ModelPool(maxBytes: 2_500_000_000, engineFactory: { _ in StubEngine() })
        _ = try await pool.load(makeModel("A", size: 1_000_000_000))
        _ = try await pool.load(makeModel("B", size: 1_000_000_000))
        _ = try await pool.load(makeModel("C", size: 1_000_000_000))
        let resident = await pool.residentModelIDs()
        XCTAssertEqual(resident, ["B", "C"])
    }

    func testPinnedNotEvicted() async throws {
        // A is pinned before the pool goes over budget, so the sweeper
        // must skip it and take B (the next-oldest) instead.
        let pool = ModelPool(maxBytes: 2_500_000_000, engineFactory: { _ in StubEngine() })
        _ = try await pool.load(makeModel("A", size: 1_000_000_000))
        await pool.setPinned("A", true)
        _ = try await pool.load(makeModel("B", size: 1_000_000_000))
        _ = try await pool.load(makeModel("C", size: 1_000_000_000))
        let resident = await pool.residentModelIDs()
        XCTAssertEqual(resident, ["A", "C"])
    }

    func testUnloadRemovesFromPool() async throws {
        let pool = ModelPool(maxBytes: 4_000_000_000, engineFactory: { _ in StubEngine() })
        _ = try await pool.load(makeModel("A"))
        await pool.unload("A")
        let resident = await pool.residentModelIDs()
        XCTAssertTrue(resident.isEmpty)
    }

    func testEngineForReturnsNilWhenNotLoaded() async {
        let pool = ModelPool(maxBytes: 4_000_000_000, engineFactory: { _ in StubEngine() })
        let engine = await pool.engine(for: "A")
        XCTAssertNil(engine)
    }
}
Loading