Skip to content
Merged
11 changes: 11 additions & 0 deletions Sources/SwitchcraftCore/Embedding/Embedder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,18 @@ public protocol Embedder: Sendable {
/// Stable identifier for the model. Recorded on `ChunkRecord.model`.
var modelIdentifier: String { get }

/// Maximum number of tokens the embedder will process in a single `encode`
/// call. Inputs that tokenise to more tokens than this limit are handled
/// according to the conformer's configured `EmbedderOverflowPolicy`.
///
/// Conformers should expose this as a `nonisolated let` stored property so
/// callers can read the limit without entering an actor.
var maxInputTokens: Int { get }

/// Encode `text` into a flat row-major `n × dims` per-token embedding
/// matrix. Returns an empty array for empty / whitespace-only text.
///
/// - Throws: `EmbedderError.inputTooLarge(actual:max:)` when the token
/// count exceeds `maxInputTokens` and the overflow policy is `.reject`.
func encode(_ text: String) async throws -> [Float]
}
13 changes: 13 additions & 0 deletions Sources/SwitchcraftCore/Embedding/EmbedderError.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// SPDX-License-Identifier: Apache-2.0
import Foundation

/// Errors thrown by `Embedder` implementations.
public enum EmbedderError: Error, Sendable, Equatable {
/// The tokenised input length exceeded the embedder's `maxInputTokens` limit
/// and the configured overflow policy is `.reject`.
///
/// - Parameters:
/// - actual: The number of tokens produced by the tokenizer for the input.
/// - max: The embedder's `maxInputTokens` limit.
case inputTooLarge(actual: Int, max: Int)
}
25 changes: 25 additions & 0 deletions Sources/SwitchcraftCore/Embedding/EmbedderOverflowPolicy.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// SPDX-License-Identifier: Apache-2.0
import Foundation

/// Controls how an `Embedder` handles inputs whose token count exceeds `maxInputTokens`.
///
/// The two policies match the primary consumer trade-offs for retrieval workloads:
///
/// - `.truncate` (default): Silently clips the token sequence to the first
/// `maxInputTokens` tokens and encodes the prefix. This is the established
/// convention (Hugging Face `truncation=True, max_length=...`) and is appropriate
/// for search, classification, and bulk-index pipelines where prefix content is
/// informative and silent data loss is acceptable.
///
/// - `.reject`: Throws `EmbedderError.inputTooLarge(actual:max:)` so the caller can
/// decide whether to skip, summarise, or split the input. Use this when silent
/// truncation would violate the application's correctness guarantees.
public enum EmbedderOverflowPolicy: Sendable, Equatable, Hashable {
/// Silently truncate the token sequence to `maxInputTokens` elements and encode
/// the prefix. No error is thrown; embeddings are returned for the truncated input.
case truncate

/// Throw `EmbedderError.inputTooLarge(actual:max:)` without calling the underlying
/// model. The caller is responsible for splitting, summarising, or skipping the input.
case reject
}
5 changes: 5 additions & 0 deletions Sources/SwitchcraftCoreML/MLPredictor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ internal protocol MLPredictor: Sendable {
func predict(input: any MLFeatureProvider) throws -> any MLFeatureProvider
}

// @unchecked @retroactive: retroactive conformances on CoreML types; safe
// because all access is gated through T5CoreMLEmbedder's actor isolation.
extension MLModel: @unchecked @retroactive Sendable {}
extension MLDictionaryFeatureProvider: @unchecked @retroactive Sendable {}

extension MLModel: MLPredictor {
internal func predict(input: any MLFeatureProvider) throws -> any MLFeatureProvider {
try self.prediction(from: input)
Expand Down
76 changes: 67 additions & 9 deletions Sources/SwitchcraftCoreML/T5CoreMLEmbedder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,20 @@ public actor T5CoreMLEmbedder: Embedder {
/// `windowSize`. `256` matches Witchcraft's sliding-window stride.
public nonisolated let stride: Int

/// Maximum total token count across all sliding windows that the embedder
/// will accept in a single `encode` call. Inputs that tokenise to more
/// tokens than this value are handled according to `overflowPolicy`.
///
/// Default `8 * windowSize` (4,096 for the standard 512-token window),
/// yielding ~15 windows at stride 256 — well below the ~577-window burst
/// that exhausted the ANE IOSurface pool. Must be ≥ `windowSize`.
public nonisolated let maxInputTokens: Int

/// Controls behaviour when a tokenised input exceeds `maxInputTokens`.
/// `.truncate` (default) silently clips to the prefix; `.reject` throws
/// `EmbedderError.inputTooLarge(actual:max:)`.
public nonisolated let overflowPolicy: EmbedderOverflowPolicy

private let tokenizer: Tokenizer
private var predictor: any MLPredictor
/// Recreates the main predictor on demand; used by proactive reload to
Expand Down Expand Up @@ -136,6 +150,12 @@ public actor T5CoreMLEmbedder: Embedder {
/// accumulated ANE IOSurface resources. Default `150` — see the stored
/// property doc-comment for tuning guidance. Existing callers that omit
/// this parameter are unaffected.
/// - maxInputTokens: maximum total token count accepted per `encode` call.
/// Inputs exceeding this limit are handled by `overflowPolicy`. Default
/// `8 * windowSize`. Must be ≥ `windowSize`.
/// - overflowPolicy: `.truncate` (default) clips oversized inputs to the
/// first `maxInputTokens` tokens; `.reject` throws
/// `EmbedderError.inputTooLarge(actual:max:)`.
/// - Throws: any error from `MLModel.compileModel(at:)` or `MLModel(contentsOf:)`.
public init(
modelURL: URL,
Expand All @@ -147,13 +167,18 @@ public actor T5CoreMLEmbedder: Embedder {
stride: Int = 256,
minNorm: Float = 1.0,
failureLogURL: URL? = nil,
reloadInterval: Int = 150
reloadInterval: Int = 150,
maxInputTokens: Int? = nil,
overflowPolicy: EmbedderOverflowPolicy = .truncate
) async throws {
precondition(dims > 0 && dims % 2 == 0,
"dims must be positive and even (Q4 codec packs two nibbles per byte)")
precondition(windowSize > 0)
precondition(stride > 0 && stride <= windowSize)
precondition(reloadInterval > 0, "reloadInterval must be positive (used as modulo divisor)")
let resolvedMaxInputTokens = maxInputTokens ?? 8 * windowSize
precondition(resolvedMaxInputTokens >= windowSize,
"maxInputTokens must be >= windowSize (got \(resolvedMaxInputTokens) < \(windowSize))")

let configuration = MLModelConfiguration()
configuration.computeUnits = computeUnits
Expand Down Expand Up @@ -198,6 +223,8 @@ public actor T5CoreMLEmbedder: Embedder {
self.stride = stride
self.failureLogURL = failureLogURL
self.reloadInterval = reloadInterval
self.maxInputTokens = resolvedMaxInputTokens
self.overflowPolicy = overflowPolicy
}

/// Convenience init that resolves the model URL from a `Bundle`.
Expand All @@ -213,7 +240,9 @@ public actor T5CoreMLEmbedder: Embedder {
stride: Int = 256,
minNorm: Float = 1.0,
failureLogURL: URL? = nil,
reloadInterval: Int = 150
reloadInterval: Int = 150,
maxInputTokens: Int? = nil,
overflowPolicy: EmbedderOverflowPolicy = .truncate
) async throws {
guard let url = bundle.url(
forResource: resourceName,
Expand All @@ -235,7 +264,9 @@ public actor T5CoreMLEmbedder: Embedder {
stride: stride,
minNorm: minNorm,
failureLogURL: failureLogURL,
reloadInterval: reloadInterval
reloadInterval: reloadInterval,
maxInputTokens: maxInputTokens,
overflowPolicy: overflowPolicy
)
}

Expand All @@ -250,12 +281,17 @@ public actor T5CoreMLEmbedder: Embedder {
stride: Int = 256,
minNorm: Float = 1.0,
modelIdentifier: String = "stub@v0",
failureLogURL: URL? = nil
failureLogURL: URL? = nil,
maxInputTokens: Int? = nil,
overflowPolicy: EmbedderOverflowPolicy = .truncate
) {
precondition(dims > 0 && dims % 2 == 0,
"dims must be positive and even")
precondition(windowSize > 0)
precondition(stride > 0 && stride <= windowSize)
let resolvedMaxInputTokens = maxInputTokens ?? 8 * windowSize
precondition(resolvedMaxInputTokens >= windowSize,
"maxInputTokens must be >= windowSize")
let capturedPredictor = predictor
self.predictorFactory = { capturedPredictor }
self.cpuPredictorFactory = nil
Expand All @@ -268,6 +304,8 @@ public actor T5CoreMLEmbedder: Embedder {
self.modelIdentifier = modelIdentifier
self.failureLogURL = failureLogURL
self.reloadInterval = 150
self.maxInputTokens = resolvedMaxInputTokens
self.overflowPolicy = overflowPolicy
}

/// Test-only init: inject a factory for predictor lifecycle testing.
Expand All @@ -287,13 +325,18 @@ public actor T5CoreMLEmbedder: Embedder {
minNorm: Float = 1.0,
modelIdentifier: String = "stub@v0",
failureLogURL: URL? = nil,
reloadInterval: Int = 150
reloadInterval: Int = 150,
maxInputTokens: Int? = nil,
overflowPolicy: EmbedderOverflowPolicy = .truncate
) throws {
precondition(dims > 0 && dims % 2 == 0,
"dims must be positive and even")
precondition(windowSize > 0)
precondition(stride > 0 && stride <= windowSize)
precondition(reloadInterval > 0, "reloadInterval must be positive (used as modulo divisor)")
let resolvedMaxInputTokens = maxInputTokens ?? 8 * windowSize
precondition(resolvedMaxInputTokens >= windowSize,
"maxInputTokens must be >= windowSize")
self.predictorFactory = predictorFactory
self.cpuPredictorFactory = cpuPredictorFactory
self.predictor = try predictorFactory()
Expand All @@ -305,6 +348,8 @@ public actor T5CoreMLEmbedder: Embedder {
self.modelIdentifier = modelIdentifier
self.failureLogURL = failureLogURL
self.reloadInterval = reloadInterval
self.maxInputTokens = resolvedMaxInputTokens
self.overflowPolicy = overflowPolicy
}

// MARK: - Embedder
Expand All @@ -317,8 +362,10 @@ public actor T5CoreMLEmbedder: Embedder {
/// failures are silently retried on CPU (see class doc-comment); callers
/// only receive an error if the retry also fails.
///
/// - Throws: `T5CoreMLEmbedderError.missingOutput` if the CoreML
/// model does not produce the expected feature dictionary;
/// - Throws: `EmbedderError.inputTooLarge(actual:max:)` when the token
/// count exceeds `maxInputTokens` and `overflowPolicy` is `.reject`;
/// `T5CoreMLEmbedderError.missingOutput` if the CoreML model does not
/// produce the expected feature dictionary;
/// `CoreMLNativeError.nativeException` if CoreML raises an internal
/// ObjC exception that the embedder cannot recover from; any
/// tokenizer-originated error.
Expand All @@ -343,9 +390,20 @@ public actor T5CoreMLEmbedder: Embedder {
let trimmed = text.trimmingCharacters(in: .whitespacesAndNewlines)
if trimmed.isEmpty { return [] }

let tokens = try tokenizer.encode(text, addSpecialTokens: true)
var tokens = try tokenizer.encode(text, addSpecialTokens: true)
if tokens.isEmpty { return [] }

// Overflow guard: prevent oversized inputs from generating hundreds of
// sliding windows and exhausting the ANE IOSurface buffer pool (ADR 022).
if tokens.count > maxInputTokens {
switch overflowPolicy {
case .truncate:
tokens = Array(tokens.prefix(maxInputTokens))
case .reject:
throw EmbedderError.inputTooLarge(actual: tokens.count, max: maxInputTokens)
}
}

// Proactive model reload: recreate the predictor every reloadInterval
// encodes to flush accumulated ANE IOSurface resources.
// Counter increments only for real inference calls (whitespace-only inputs
Expand Down Expand Up @@ -438,7 +496,7 @@ public actor T5CoreMLEmbedder: Embedder {
/// Run one window prediction with autoreleasepool drainage, reactive reload,
/// ANE retry, and IOSurface CPU fallback.
private func predictWindow(
provider: any MLFeatureProvider,
provider: MLDictionaryFeatureProvider,
inputLength: Int,
windowTokenCount: Int
) throws -> any MLFeatureProvider {
Expand Down
24 changes: 22 additions & 2 deletions Sources/SwitchcraftMetal/T5MetalEmbedder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ public actor T5MetalEmbedder: Embedder {
public nonisolated let minNorm: Float
public nonisolated let windowSize: Int
public nonisolated let stride: Int
public nonisolated let maxInputTokens: Int
public nonisolated let overflowPolicy: EmbedderOverflowPolicy

// MARK: - Architecture constants
//
Expand Down Expand Up @@ -149,12 +151,17 @@ public actor T5MetalEmbedder: Embedder {
windowSize windowSizeParam: Int = 512,
stride strideParam: Int = 256,
minNorm minNormParam: Float = 1.0,
modelIdentifier modelIdentifierParam: String = "google/xtr-base-en@v1+gguf"
modelIdentifier modelIdentifierParam: String = "google/xtr-base-en@v1+gguf",
maxInputTokens maxInputTokensParam: Int? = nil,
overflowPolicy overflowPolicyParam: EmbedderOverflowPolicy = .truncate
) async throws {
precondition(dimsParam > 0 && dimsParam % 2 == 0,
"dims must be positive and even (Q4 codec packs two nibbles per byte)")
precondition(windowSizeParam > 0)
precondition(strideParam > 0 && strideParam <= windowSizeParam)
let resolvedMaxInputTokens = maxInputTokensParam ?? 8 * windowSizeParam
precondition(resolvedMaxInputTokens >= windowSizeParam,
"maxInputTokens must be >= windowSize")

guard let context = MetalContext.shared else {
throw T5MetalEmbedderError.metalUnavailable
Expand Down Expand Up @@ -414,6 +421,8 @@ public actor T5MetalEmbedder: Embedder {
self.minNorm = minNormParam
self.windowSize = windowSizeParam
self.stride = strideParam
self.maxInputTokens = resolvedMaxInputTokens
self.overflowPolicy = overflowPolicyParam
self.tokenizer = tokenizer
self.context = context
self.layers = layerWeights
Expand Down Expand Up @@ -456,9 +465,20 @@ public actor T5MetalEmbedder: Embedder {
let trimmed = text.trimmingCharacters(in: .whitespacesAndNewlines)
if trimmed.isEmpty { return [] }

let tokens = try tokenizer.encode(text, addSpecialTokens: true)
var tokens = try tokenizer.encode(text, addSpecialTokens: true)
if tokens.isEmpty { return [] }

// Overflow guard: prevent oversized inputs from generating hundreds of
// Metal command buffers and exhausting device memory (ADR 022).
if tokens.count > maxInputTokens {
switch overflowPolicy {
case .truncate:
tokens = Array(tokens.prefix(maxInputTokens))
case .reject:
throw EmbedderError.inputTooLarge(actual: tokens.count, max: maxInputTokens)
}
}

let starts = SlidingWindow.plan(
tokenCount: tokens.count,
windowSize: windowSize,
Expand Down
1 change: 1 addition & 0 deletions Tests/SwitchcraftTests/SearchTimeoutTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ struct SearchTimeoutTests {

var dims: Int { inner.dims }
var modelIdentifier: String { inner.modelIdentifier }
var maxInputTokens: Int { inner.maxInputTokens }

func encode(_ text: String) async throws -> [Float] {
// Task.sleep respects task cancellation: it throws
Expand Down
1 change: 1 addition & 0 deletions Tests/SwitchcraftTests/Support/MockEmbedder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import SwitchcraftCore
struct MockEmbedder: Embedder, Sendable {
let dims: Int
let modelIdentifier: String
let maxInputTokens: Int = Int.max

init(dims: Int = 128, modelIdentifier: String? = nil) {
precondition(dims > 0 && dims % 2 == 0,
Expand Down
1 change: 1 addition & 0 deletions Tests/SwitchcraftTests/SwitchcraftStoreTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ struct SwitchcraftStoreTests {
struct OddDimsEmbedder: Embedder {
let dims = 33
let modelIdentifier = "odd"
let maxInputTokens: Int = Int.max
func encode(_ text: String) async throws -> [Float] { [] }
}

Expand Down
Loading
Loading