diff --git a/Sources/SwitchcraftCore/Embedding/Embedder.swift b/Sources/SwitchcraftCore/Embedding/Embedder.swift index 9bbf30b..536f902 100644 --- a/Sources/SwitchcraftCore/Embedding/Embedder.swift +++ b/Sources/SwitchcraftCore/Embedding/Embedder.swift @@ -27,7 +27,18 @@ public protocol Embedder: Sendable { /// Stable identifier for the model. Recorded on `ChunkRecord.model`. var modelIdentifier: String { get } + /// Maximum number of tokens the embedder will process in a single `encode` + /// call. Inputs that tokenise to more tokens than this limit are handled + /// according to the conformer's configured `EmbedderOverflowPolicy`. + /// + /// Conformers should expose this as a `nonisolated let` stored property so + /// callers can read the limit without entering an actor. + var maxInputTokens: Int { get } + /// Encode `text` into a flat row-major `n × dims` per-token embedding /// matrix. Returns an empty array for empty / whitespace-only text. + /// + /// - Throws: `EmbedderError.inputTooLarge(actual:max:)` when the token + /// count exceeds `maxInputTokens` and the overflow policy is `.reject`. func encode(_ text: String) async throws -> [Float] } diff --git a/Sources/SwitchcraftCore/Embedding/EmbedderError.swift b/Sources/SwitchcraftCore/Embedding/EmbedderError.swift new file mode 100644 index 0000000..588130c --- /dev/null +++ b/Sources/SwitchcraftCore/Embedding/EmbedderError.swift @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: Apache-2.0 +import Foundation + +/// Errors thrown by `Embedder` implementations. +public enum EmbedderError: Error, Sendable, Equatable { + /// The tokenised input length exceeded the embedder's `maxInputTokens` limit + /// and the configured overflow policy is `.reject`. + /// + /// - Parameters: + /// - actual: The number of tokens produced by the tokenizer for the input. + /// - max: The embedder's `maxInputTokens` limit. + case inputTooLarge(actual: Int, max: Int) +} diff --git a/Sources/SwitchcraftCore/Embedding/EmbedderOverflowPolicy.swift b/Sources/SwitchcraftCore/Embedding/EmbedderOverflowPolicy.swift new file mode 100644 index 0000000..a8f0a06 --- /dev/null +++ b/Sources/SwitchcraftCore/Embedding/EmbedderOverflowPolicy.swift @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +import Foundation + +/// Controls how an `Embedder` handles inputs whose token count exceeds `maxInputTokens`. +/// +/// The two policies match the primary consumer trade-offs for retrieval workloads: +/// +/// - `.truncate` (default): Silently clips the token sequence to the first +/// `maxInputTokens` tokens and encodes the prefix. This is the established +/// convention (Hugging Face `truncation=True, max_length=...`) and is appropriate +/// for search, classification, and bulk-index pipelines where prefix content is +/// informative and silent data loss is acceptable. +/// +/// - `.reject`: Throws `EmbedderError.inputTooLarge(actual:max:)` so the caller can +/// decide whether to skip, summarise, or split the input. Use this when silent +/// truncation would violate the application's correctness guarantees. +public enum EmbedderOverflowPolicy: Sendable, Equatable, Hashable { + /// Silently truncate the token sequence to `maxInputTokens` elements and encode + /// the prefix. No error is thrown; embeddings are returned for the truncated input. + case truncate + + /// Throw `EmbedderError.inputTooLarge(actual:max:)` without calling the underlying + /// model. The caller is responsible for splitting, summarising, or skipping the input. + case reject +} diff --git a/Sources/SwitchcraftCoreML/MLPredictor.swift b/Sources/SwitchcraftCoreML/MLPredictor.swift index 977593b..a55b001 100644 --- a/Sources/SwitchcraftCoreML/MLPredictor.swift +++ b/Sources/SwitchcraftCoreML/MLPredictor.swift @@ -16,6 +16,11 @@ internal protocol MLPredictor: Sendable { func predict(input: any MLFeatureProvider) throws -> any MLFeatureProvider } +// @unchecked @retroactive: retroactive conformances on CoreML types; safe +// because all access is gated through T5CoreMLEmbedder's actor isolation. +extension MLModel: @unchecked @retroactive Sendable {} +extension MLDictionaryFeatureProvider: @unchecked @retroactive Sendable {} + extension MLModel: MLPredictor { internal func predict(input: any MLFeatureProvider) throws -> any MLFeatureProvider { try self.prediction(from: input) diff --git a/Sources/SwitchcraftCoreML/T5CoreMLEmbedder.swift b/Sources/SwitchcraftCoreML/T5CoreMLEmbedder.swift index bea666b..351cdeb 100644 --- a/Sources/SwitchcraftCoreML/T5CoreMLEmbedder.swift +++ b/Sources/SwitchcraftCoreML/T5CoreMLEmbedder.swift @@ -87,6 +87,20 @@ public actor T5CoreMLEmbedder: Embedder { /// `windowSize`. `256` matches Witchcraft's sliding-window stride. public nonisolated let stride: Int + /// Maximum total token count across all sliding windows that the embedder + /// will accept in a single `encode` call. Inputs that tokenise to more + /// tokens than this value are handled according to `overflowPolicy`. + /// + /// Default `8 * windowSize` (4,096 for the standard 512-token window), + /// yielding ~15 windows at stride 256 — well below the ~577-window burst + /// that exhausted the ANE IOSurface pool. Must be ≥ `windowSize`. + public nonisolated let maxInputTokens: Int + + /// Controls behaviour when a tokenised input exceeds `maxInputTokens`. + /// `.truncate` (default) silently clips to the prefix; `.reject` throws + /// `EmbedderError.inputTooLarge(actual:max:)`. + public nonisolated let overflowPolicy: EmbedderOverflowPolicy + private let tokenizer: Tokenizer private var predictor: any MLPredictor /// Recreates the main predictor on demand; used by proactive reload to @@ -136,6 +150,12 @@ public actor T5CoreMLEmbedder: Embedder { /// accumulated ANE IOSurface resources. Default `150` — see the stored /// property doc-comment for tuning guidance. Existing callers that omit /// this parameter are unaffected. + /// - maxInputTokens: maximum total token count accepted per `encode` call. + /// Inputs exceeding this limit are handled by `overflowPolicy`. Default + /// `8 * windowSize`. Must be ≥ `windowSize`. + /// - overflowPolicy: `.truncate` (default) clips oversized inputs to the + /// first `maxInputTokens` tokens; `.reject` throws + /// `EmbedderError.inputTooLarge(actual:max:)`. /// - Throws: any error from `MLModel.compileModel(at:)` or `MLModel(contentsOf:)`. public init( modelURL: URL, @@ -147,13 +167,18 @@ public actor T5CoreMLEmbedder: Embedder { stride: Int = 256, minNorm: Float = 1.0, failureLogURL: URL? = nil, - reloadInterval: Int = 150 + reloadInterval: Int = 150, + maxInputTokens: Int? = nil, + overflowPolicy: EmbedderOverflowPolicy = .truncate ) async throws { precondition(dims > 0 && dims % 2 == 0, "dims must be positive and even (Q4 codec packs two nibbles per byte)") precondition(windowSize > 0) precondition(stride > 0 && stride <= windowSize) precondition(reloadInterval > 0, "reloadInterval must be positive (used as modulo divisor)") + let resolvedMaxInputTokens = maxInputTokens ?? 8 * windowSize + precondition(resolvedMaxInputTokens >= windowSize, + "maxInputTokens must be >= windowSize (got \(resolvedMaxInputTokens) < \(windowSize))") let configuration = MLModelConfiguration() configuration.computeUnits = computeUnits @@ -198,6 +223,8 @@ public actor T5CoreMLEmbedder: Embedder { self.stride = stride self.failureLogURL = failureLogURL self.reloadInterval = reloadInterval + self.maxInputTokens = resolvedMaxInputTokens + self.overflowPolicy = overflowPolicy } /// Convenience init that resolves the model URL from a `Bundle`. @@ -213,7 +240,9 @@ public actor T5CoreMLEmbedder: Embedder { stride: Int = 256, minNorm: Float = 1.0, failureLogURL: URL? = nil, - reloadInterval: Int = 150 + reloadInterval: Int = 150, + maxInputTokens: Int? = nil, + overflowPolicy: EmbedderOverflowPolicy = .truncate ) async throws { guard let url = bundle.url( forResource: resourceName, @@ -235,7 +264,9 @@ public actor T5CoreMLEmbedder: Embedder { stride: stride, minNorm: minNorm, failureLogURL: failureLogURL, - reloadInterval: reloadInterval + reloadInterval: reloadInterval, + maxInputTokens: maxInputTokens, + overflowPolicy: overflowPolicy ) } @@ -250,12 +281,17 @@ public actor T5CoreMLEmbedder: Embedder { stride: Int = 256, minNorm: Float = 1.0, modelIdentifier: String = "stub@v0", - failureLogURL: URL? = nil + failureLogURL: URL? = nil, + maxInputTokens: Int? = nil, + overflowPolicy: EmbedderOverflowPolicy = .truncate ) { precondition(dims > 0 && dims % 2 == 0, "dims must be positive and even") precondition(windowSize > 0) precondition(stride > 0 && stride <= windowSize) + let resolvedMaxInputTokens = maxInputTokens ?? 8 * windowSize + precondition(resolvedMaxInputTokens >= windowSize, + "maxInputTokens must be >= windowSize") let capturedPredictor = predictor self.predictorFactory = { capturedPredictor } self.cpuPredictorFactory = nil @@ -268,6 +304,8 @@ public actor T5CoreMLEmbedder: Embedder { self.modelIdentifier = modelIdentifier self.failureLogURL = failureLogURL self.reloadInterval = 150 + self.maxInputTokens = resolvedMaxInputTokens + self.overflowPolicy = overflowPolicy } /// Test-only init: inject a factory for predictor lifecycle testing. @@ -287,13 +325,18 @@ public actor T5CoreMLEmbedder: Embedder { minNorm: Float = 1.0, modelIdentifier: String = "stub@v0", failureLogURL: URL? = nil, - reloadInterval: Int = 150 + reloadInterval: Int = 150, + maxInputTokens: Int? = nil, + overflowPolicy: EmbedderOverflowPolicy = .truncate ) throws { precondition(dims > 0 && dims % 2 == 0, "dims must be positive and even") precondition(windowSize > 0) precondition(stride > 0 && stride <= windowSize) precondition(reloadInterval > 0, "reloadInterval must be positive (used as modulo divisor)") + let resolvedMaxInputTokens = maxInputTokens ?? 8 * windowSize + precondition(resolvedMaxInputTokens >= windowSize, + "maxInputTokens must be >= windowSize") self.predictorFactory = predictorFactory self.cpuPredictorFactory = cpuPredictorFactory self.predictor = try predictorFactory() @@ -305,6 +348,8 @@ public actor T5CoreMLEmbedder: Embedder { self.modelIdentifier = modelIdentifier self.failureLogURL = failureLogURL self.reloadInterval = reloadInterval + self.maxInputTokens = resolvedMaxInputTokens + self.overflowPolicy = overflowPolicy } // MARK: - Embedder @@ -317,8 +362,10 @@ public actor T5CoreMLEmbedder: Embedder { /// failures are silently retried on CPU (see class doc-comment); callers /// only receive an error if the retry also fails. /// - /// - Throws: `T5CoreMLEmbedderError.missingOutput` if the CoreML - /// model does not produce the expected feature dictionary; + /// - Throws: `EmbedderError.inputTooLarge(actual:max:)` when the token + /// count exceeds `maxInputTokens` and `overflowPolicy` is `.reject`; + /// `T5CoreMLEmbedderError.missingOutput` if the CoreML model does not + /// produce the expected feature dictionary; /// `CoreMLNativeError.nativeException` if CoreML raises an internal /// ObjC exception that the embedder cannot recover from; any /// tokenizer-originated error. @@ -343,9 +390,20 @@ public actor T5CoreMLEmbedder: Embedder { let trimmed = text.trimmingCharacters(in: .whitespacesAndNewlines) if trimmed.isEmpty { return [] } - let tokens = try tokenizer.encode(text, addSpecialTokens: true) + var tokens = try tokenizer.encode(text, addSpecialTokens: true) if tokens.isEmpty { return [] } + // Overflow guard: prevent oversized inputs from generating hundreds of + // sliding windows and exhausting the ANE IOSurface buffer pool (ADR 022). + if tokens.count > maxInputTokens { + switch overflowPolicy { + case .truncate: + tokens = Array(tokens.prefix(maxInputTokens)) + case .reject: + throw EmbedderError.inputTooLarge(actual: tokens.count, max: maxInputTokens) + } + } + // Proactive model reload: recreate the predictor every reloadInterval // encodes to flush accumulated ANE IOSurface resources. // Counter increments only for real inference calls (whitespace-only inputs @@ -438,7 +496,7 @@ public actor T5CoreMLEmbedder: Embedder { /// Run one window prediction with autoreleasepool drainage, reactive reload, /// ANE retry, and IOSurface CPU fallback. private func predictWindow( - provider: any MLFeatureProvider, + provider: MLDictionaryFeatureProvider, inputLength: Int, windowTokenCount: Int ) throws -> any MLFeatureProvider { diff --git a/Sources/SwitchcraftMetal/T5MetalEmbedder.swift b/Sources/SwitchcraftMetal/T5MetalEmbedder.swift index 94afda9..6922025 100644 --- a/Sources/SwitchcraftMetal/T5MetalEmbedder.swift +++ b/Sources/SwitchcraftMetal/T5MetalEmbedder.swift @@ -65,6 +65,8 @@ public actor T5MetalEmbedder: Embedder { public nonisolated let minNorm: Float public nonisolated let windowSize: Int public nonisolated let stride: Int + public nonisolated let maxInputTokens: Int + public nonisolated let overflowPolicy: EmbedderOverflowPolicy // MARK: - Architecture constants // @@ -149,12 +151,17 @@ public actor T5MetalEmbedder: Embedder { windowSize windowSizeParam: Int = 512, stride strideParam: Int = 256, minNorm minNormParam: Float = 1.0, - modelIdentifier modelIdentifierParam: String = "google/xtr-base-en@v1+gguf" + modelIdentifier modelIdentifierParam: String = "google/xtr-base-en@v1+gguf", + maxInputTokens maxInputTokensParam: Int? = nil, + overflowPolicy overflowPolicyParam: EmbedderOverflowPolicy = .truncate ) async throws { precondition(dimsParam > 0 && dimsParam % 2 == 0, "dims must be positive and even (Q4 codec packs two nibbles per byte)") precondition(windowSizeParam > 0) precondition(strideParam > 0 && strideParam <= windowSizeParam) + let resolvedMaxInputTokens = maxInputTokensParam ?? 8 * windowSizeParam + precondition(resolvedMaxInputTokens >= windowSizeParam, + "maxInputTokens must be >= windowSize") guard let context = MetalContext.shared else { throw T5MetalEmbedderError.metalUnavailable @@ -414,6 +421,8 @@ public actor T5MetalEmbedder: Embedder { self.minNorm = minNormParam self.windowSize = windowSizeParam self.stride = strideParam + self.maxInputTokens = resolvedMaxInputTokens + self.overflowPolicy = overflowPolicyParam self.tokenizer = tokenizer self.context = context self.layers = layerWeights @@ -456,9 +465,20 @@ public actor T5MetalEmbedder: Embedder { let trimmed = text.trimmingCharacters(in: .whitespacesAndNewlines) if trimmed.isEmpty { return [] } - let tokens = try tokenizer.encode(text, addSpecialTokens: true) + var tokens = try tokenizer.encode(text, addSpecialTokens: true) if tokens.isEmpty { return [] } + // Overflow guard: prevent oversized inputs from generating hundreds of + // Metal command buffers and exhausting device memory (ADR 022). + if tokens.count > maxInputTokens { + switch overflowPolicy { + case .truncate: + tokens = Array(tokens.prefix(maxInputTokens)) + case .reject: + throw EmbedderError.inputTooLarge(actual: tokens.count, max: maxInputTokens) + } + } + let starts = SlidingWindow.plan( tokenCount: tokens.count, windowSize: windowSize, diff --git a/Tests/SwitchcraftTests/SearchTimeoutTests.swift b/Tests/SwitchcraftTests/SearchTimeoutTests.swift index fa57a0a..f82763a 100644 --- a/Tests/SwitchcraftTests/SearchTimeoutTests.swift +++ b/Tests/SwitchcraftTests/SearchTimeoutTests.swift @@ -55,6 +55,7 @@ struct SearchTimeoutTests { var dims: Int { inner.dims } var modelIdentifier: String { inner.modelIdentifier } + var maxInputTokens: Int { inner.maxInputTokens } func encode(_ text: String) async throws -> [Float] { // Task.sleep respects task cancellation: it throws diff --git a/Tests/SwitchcraftTests/Support/MockEmbedder.swift b/Tests/SwitchcraftTests/Support/MockEmbedder.swift index fd15c58..f7f4a6d 100644 --- a/Tests/SwitchcraftTests/Support/MockEmbedder.swift +++ b/Tests/SwitchcraftTests/Support/MockEmbedder.swift @@ -15,6 +15,7 @@ import SwitchcraftCore struct MockEmbedder: Embedder, Sendable { let dims: Int let modelIdentifier: String + let maxInputTokens: Int = Int.max init(dims: Int = 128, modelIdentifier: String? = nil) { precondition(dims > 0 && dims % 2 == 0, diff --git a/Tests/SwitchcraftTests/SwitchcraftStoreTests.swift b/Tests/SwitchcraftTests/SwitchcraftStoreTests.swift index 4f24545..aba9e62 100644 --- a/Tests/SwitchcraftTests/SwitchcraftStoreTests.swift +++ b/Tests/SwitchcraftTests/SwitchcraftStoreTests.swift @@ -339,6 +339,7 @@ struct SwitchcraftStoreTests { struct OddDimsEmbedder: Embedder { let dims = 33 let modelIdentifier = "odd" + let maxInputTokens: Int = Int.max func encode(_ text: String) async throws -> [Float] { [] } } diff --git a/Tests/SwitchcraftTests/T5CoreMLEmbedderOverflowTests.swift b/Tests/SwitchcraftTests/T5CoreMLEmbedderOverflowTests.swift new file mode 100644 index 0000000..0af1a1d --- /dev/null +++ b/Tests/SwitchcraftTests/T5CoreMLEmbedderOverflowTests.swift @@ -0,0 +1,207 @@ +// SPDX-License-Identifier: Apache-2.0 +import Foundation +import Testing +import SwitchcraftCore +@testable import SwitchcraftCoreML + +#if canImport(CoreML) +import CoreML + +/// Tests for `T5CoreMLEmbedder`'s overflow guard (`maxInputTokens` / `overflowPolicy`). +/// +/// No CoreML model asset is required — `CountingStubPredictor` is injected +/// via the factory-based internal init. All tests use a small `windowSize` and +/// `dims` to keep per-iteration allocations tiny. +@Suite("T5CoreMLEmbedder overflow guard") +struct T5CoreMLEmbedderOverflowTests { + + // MARK: - Helpers + + private static func makeTokenizer() throws -> Tokenizer { + let url = try #require( + Bundle.module.url( + forResource: "xtr-base-en.tokenizer", + withExtension: "json", + subdirectory: "Fixtures" + ), + "tokenizer fixture missing from test bundle" + ) + return try Tokenizer(contentsOf: url.path) + } + + /// Returns a string that reliably tokenises to more than `minTokens` tokens. + /// Repeats "hello " enough times that the BPE tokenizer (plus the T5 + /// sentinel) always produces more than `minTokens` tokens. + private static func oversizedInput(minTokens: Int) -> String { + // "hello" is a single BPE token in xtr-base-en; " " is a word-start + // prefix on the following token. Repeating "hello " minTokens+5 times + // reliably produces more than minTokens tokens after tokenisation. + return String(repeating: "hello ", count: minTokens + 5) + } + + // MARK: - R7d: maxInputTokens property + + @Test("maxInputTokens property returns the configured value") + func testMaxInputTokensPropertyReturnsConfiguredValue() throws { + let tokenizer = try Self.makeTokenizer() + let dims = 16 + let windowSize = 8 + let customLimit = 24 // >= windowSize + + let embedder = try T5CoreMLEmbedder( + predictorFactory: { CountingStubPredictor(dims: dims) }, + tokenizer: tokenizer, + dims: dims, + windowSize: windowSize, + stride: 4, + minNorm: 1.0, + maxInputTokens: customLimit + ) + + #expect(embedder.maxInputTokens == customLimit) + } + + // MARK: - R7a: truncate policy + + @Test("Truncate policy encodes oversized input without error and returns non-empty embeddings") + func testTruncatePolicyEncodesOversizedInputWithoutError() async throws { + let tokenizer = try Self.makeTokenizer() + let dims = 16 + let windowSize = 8 + let maxTokens = 16 // >= windowSize; guarantees at least 1 full window + + let embedder = try T5CoreMLEmbedder( + predictorFactory: { CountingStubPredictor(dims: dims) }, + tokenizer: tokenizer, + dims: dims, + windowSize: windowSize, + stride: 4, + minNorm: 1.0, + maxInputTokens: maxTokens, + overflowPolicy: .truncate + ) + + // Build an input that definitely exceeds maxTokens. + let input = Self.oversizedInput(minTokens: maxTokens) + let result = try await embedder.encode(input) + + // Truncation succeeds: the result must be non-empty and at most + // maxTokens rows wide (each row is `dims` floats). + #expect(!result.isEmpty, "truncated encode must return non-empty embeddings") + #expect(result.count <= maxTokens * dims, + "result has \(result.count) floats; expected ≤ \(maxTokens * dims) for \(maxTokens) tokens × \(dims) dims") + } + + // MARK: - R7b: reject policy + + @Test("Reject policy throws EmbedderError.inputTooLarge with correct actual and max values") + func testRejectPolicyThrowsInputTooLargeError() async throws { + let tokenizer = try Self.makeTokenizer() + let dims = 16 + let windowSize = 8 + let maxTokens = 16 + + let embedder = try T5CoreMLEmbedder( + predictorFactory: { CountingStubPredictor(dims: dims) }, + tokenizer: tokenizer, + dims: dims, + windowSize: windowSize, + stride: 4, + minNorm: 1.0, + maxInputTokens: maxTokens, + overflowPolicy: .reject + ) + + // Count the tokens the tokenizer actually produces so we can verify + // the `actual` field on the thrown error. + let input = Self.oversizedInput(minTokens: maxTokens) + let actualTokenCount = try tokenizer.encode(input, addSpecialTokens: true).count + try #require(actualTokenCount > maxTokens, + "test precondition: input must tokenise to > \(maxTokens) tokens; got \(actualTokenCount)") + + await #expect(throws: EmbedderError.inputTooLarge(actual: actualTokenCount, max: maxTokens)) { + _ = try await embedder.encode(input) + } + } + + // MARK: - R5, R7c: stress test + + @Test("1,000 encode calls with oversized inputs interleaved complete without ANE pool degradation (truncate policy)") + func testOverflowStress1000CallsTruncatePolicy() async throws { + let tokenizer = try Self.makeTokenizer() + let dims = 16 + let windowSize = 8 + let maxTokens = 16 + + let embedder = try T5CoreMLEmbedder( + predictorFactory: { CountingStubPredictor(dims: dims) }, + tokenizer: tokenizer, + dims: dims, + windowSize: windowSize, + stride: 4, + minNorm: 1.0, + maxInputTokens: maxTokens, + overflowPolicy: .truncate + ) + + let normalInput = "semantic search" + let oversizedInput = Self.oversizedInput(minTokens: maxTokens) + + for i in 0..<1_000 { + let text = (i == 250 || i == 750) ? oversizedInput : normalInput + // All calls must succeed (oversized inputs are silently truncated). + let result = try await embedder.encode(text) + #expect(!result.isEmpty, "encode returned empty at index \(i)") + } + } + + @Test("1,000 encode calls: reject policy throws on oversized inputs but normal-sized calls always succeed") + func testOverflowStress1000CallsRejectPolicy() async throws { + let tokenizer = try Self.makeTokenizer() + let dims = 16 + let windowSize = 8 + let maxTokens = 16 + + let embedder = try T5CoreMLEmbedder( + predictorFactory: { CountingStubPredictor(dims: dims) }, + tokenizer: tokenizer, + dims: dims, + windowSize: windowSize, + stride: 4, + minNorm: 1.0, + maxInputTokens: maxTokens, + overflowPolicy: .reject + ) + + let normalInput = "semantic search" + let oversizedInput = Self.oversizedInput(minTokens: maxTokens) + + var oversizedErrors = 0 + for i in 0..<1_000 { + if i == 250 || i == 750 { + // Must throw inputTooLarge — and must NOT leave the embedder in a + // broken state (subsequent normal calls must still work). + do { + _ = try await embedder.encode(oversizedInput) + Issue.record("Expected EmbedderError.inputTooLarge at index \(i) but no error was thrown") + } catch let err as EmbedderError { + if case .inputTooLarge = err { + oversizedErrors += 1 + } else { + Issue.record("Unexpected EmbedderError at index \(i): \(err)") + } + } catch { + Issue.record("Unexpected error type at index \(i): \(error)") + } + } else { + // Normal-sized call must always succeed. + let result = try await embedder.encode(normalInput) + #expect(!result.isEmpty, "normal encode returned empty at index \(i)") + } + } + + #expect(oversizedErrors == 2, "expected 2 inputTooLarge errors (indices 250 and 750), got \(oversizedErrors)") + } +} + +#endif diff --git a/adrs/022-embedder-overflow-guard.md b/adrs/022-embedder-overflow-guard.md new file mode 100644 index 0000000..3def3a2 --- /dev/null +++ b/adrs/022-embedder-overflow-guard.md @@ -0,0 +1,94 @@ +# ADR 022: Embedder overflow guard — `maxInputTokens` and `EmbedderOverflowPolicy` + +**Status:** Accepted +**Date:** 2026-05-06 +**Issue:** [#89](https://github.com/totalslacker/switchcraft/issues/89) + +## Context + +`T5CoreMLEmbedder` uses a sliding-window strategy (ADR 011) to handle inputs longer than +the CoreML model's fixed 512-token input. For a 593,285-character page that tokenises to +~148,000 tokens, `SlidingWindow.plan` generates approximately 577 windows. 577 consecutive +CoreML predictions within a single `encode()` call exhausted the ANE IOSurface buffer +pool, leaving the embedder in a state where every subsequent inference — including +small inputs — failed with `MLE5OutputPortBinder bindAndReturnError`. Process restart +was required to recover. + +ADR 021 adds three reactive/proactive defence layers (autoreleasepool drainage, proactive +model reload, CPU fallback). This ADR adds the structural prevention layer: refuse to +generate a window burst this large in the first place. + +## Decision + +### 1. `EmbedderOverflowPolicy` and `EmbedderError` live in `SwitchcraftCore` + +Both types are pure Swift with no CoreML or Metal dependency. Placing them in +`SwitchcraftCore` (not `SwitchcraftCoreML`) makes them: + +- Available to any `Embedder` conformer, including `T5MetalEmbedder` and future backends. +- Reachable by callers who hold an `any Embedder` reference without downcasting. +- Not gated behind `#if canImport(CoreML)`. + +### 2. `maxInputTokens: Int { get }` is added to the `Embedder` protocol + +Adding the limit to the protocol rather than only to concrete types: + +- Makes the safety contract queryable on any `Embedder`-typed reference. +- Forces new conformers to declare an explicit limit — an unknown embedder with no + stated limit is unsafe for bulk-index consumers. +- Follows the existing pattern where `dims` and `modelIdentifier` are protocol-level. + +Conformers are encouraged (but not required by the compiler) to implement this as a +`nonisolated let` stored property so callers can read it without entering an actor. + +### 3. Default `maxInputTokens = 8 * windowSize` + +For the standard 512-token window, this is 4,096 total tokens → ~15 windows at stride 256. +The multiplier form scales naturally when `windowSize` is customised at init. 4,096 covers +most real-world documents (~3,000 words) while keeping the window count well below the +threshold observed to exhaust the ANE pool (~577). + +Callers running in environments with larger IOSurface budgets (or no ANE) may pass a larger +value at init. A `precondition(maxInputTokens >= windowSize)` guards against setting the +limit below one window, which would make every non-trivial encode either truncate to +padding or always throw. + +### 4. Two policies: `.truncate` (default) and `.reject` + +`.truncate` is the default for three reasons: +- It matches the established HuggingFace convention (`truncation=True, max_length=...`). +- The primary use case (retrieval, classification, bulk indexing) benefits from prefix + embeddings when the full document is too long. +- It preserves backward compatibility — existing callers that do not pass `overflowPolicy` + get silent truncation rather than a new throw path. + +`.reject` is provided for callers that cannot accept silent data loss (e.g., applications +that want to split long documents themselves and know exactly which text was embedded). + +### 5. `T5MetalEmbedder` gains the same guard + +`T5MetalEmbedder` uses `SlidingWindow` identically to `T5CoreMLEmbedder`. While it lacks +the IOSurface failure mode, a 577-Metal-command-buffer burst from one `encode` call is both +slow and memory-intensive. With `maxInputTokens` now on the protocol, adding the guard to +the Metal embedder is both consistent and low-risk. No stub-based tests are added for the +Metal path because `T5MetalEmbedder.init` requires a real Metal device and GGUF asset; +the CoreML stub tests provide adequate algorithmic coverage of the overflow guard logic. + +## Guard placement + +The overflow guard is inserted in `encode(_:)` after `tokenizer.encode(text, addSpecialTokens: true)` +and before `SlidingWindow.plan`. This is the earliest point where the authoritative token +count is known. It fires inside the re-entrancy guard's `inFlight` window, so the `defer` +block that releases the next waiter still executes correctly when `.reject` throws. + +## Consequences + +- **Breaking change:** `Embedder` protocol gains a required property `maxInputTokens: Int { get }`. + All existing conformers must implement it. Test mocks and local test stubs return `Int.max`. +- `T5CoreMLEmbedder` and `T5MetalEmbedder` gain two new init parameters + (`maxInputTokens: Int? = nil`, `overflowPolicy: EmbedderOverflowPolicy = .truncate`) + with defaults that maintain backward compatibility. +- `EmbedderError` is a new public enum in `SwitchcraftCore`. It will accumulate future + embedder-level errors (not model-specific errors, which remain on `T5CoreMLEmbedderError`). +- Bulk-index consumers indexing very large documents will silently receive prefix embeddings + unless they explicitly configure `.reject`. This is documented on the `encode` method. diff --git a/docs/Plan.md b/docs/Plan.md index 6b255dc..c32ccb2 100644 --- a/docs/Plan.md +++ b/docs/Plan.md @@ -318,6 +318,7 @@ Track progress by checking off items as they land. Effort estimates and notes fo - [x] **T5 encoder → CoreML** (1-2 weeks) — Model conversion + Swift wrapper - [x] **T5CoreMLEmbedder crash-safety** (#78) — `MLExceptionCatcher` ObjC `@try/@catch` bridge (separate `SwitchcraftCoreMLObjC` clang target per ADR 018); `catchingNSException` Swift facade converts CoreML internal `NSException`s to `CoreMLNativeError.nativeException` typed errors; `MLPredictor` internal protocol DI seam; re-entrancy guard (`inFlight` + `waiters` continuations); optional `failureLogURL` for JSONL crash telemetry + `os.Logger` logging - [x] **T5CoreMLEmbedder ANE IOSurface fix** (#87) — `autoreleasepool` per window, proactive model reload every `reloadInterval` encodes (default 500, tunable), reactive CPU-fallback with JSONL recovery telemetry (`"recovered_iosurface_exhaustion"`); stub stress test (5k iterations, always-on CI) + real-asset stress test (10k iterations, asset-gated); ADR 021 +- [x] **Embedder overflow guard** (#89) — `maxInputTokens: Int` added to `Embedder` protocol; `EmbedderOverflowPolicy` (`.truncate` / `.reject`) + `EmbedderError.inputTooLarge` in `SwitchcraftCore`; overflow guard in `T5CoreMLEmbedder.encode(_:)` and `T5MetalEmbedder.encode(_:)` between tokenization and `SlidingWindow.plan`; default `8 * windowSize` (4,096 tokens → ~15 windows); prevents ANE pool poisoning from oversized inputs; ADR 022 - [x] **T5CoreMLEmbedder ANE IOSurface mitigation hardening** (#90) — Fix three bugs in post-#88 production code: (1) silent CPU fallback gap: new `"cpu_fallback_failed"` JSONL category with `cpuErrorName`/`cpuErrorReason`/`cpuCallStack` fields; (2) `reloadInterval` default lowered 500→150 (below observed 388-call production failure point); (3) reactive reload + ANE retry added to Layer 3 before CPU fallback; per-window timing via `os.Logger`; Scenario A/B/C mock tests; ADR 021 amended - [x] **K-means clustering** (1 week) — Standard algorithm, use Accelerate - [x] **4-bit residual codec** (1 week) — ~200 lines, bit-level packing; round-trip property tests