From 7a6115fa6ead2c149149d846efb6626828cf53f4 Mon Sep 17 00:00:00 2001 From: Shannon Holland Date: Wed, 6 May 2026 08:21:32 -0700 Subject: [PATCH 1/5] feat(embedder): add overflow guard with configurable policy to prevent ANE IOSurface pool poisoning MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds `maxInputTokens` + `EmbedderOverflowPolicy` to `T5CoreMLEmbedder` and `T5MetalEmbedder`. Inputs exceeding the limit are silently truncated (.truncate, default) or throw `EmbedderError.inputTooLarge(actual:max:)` (.reject). Default limit is 8 * windowSize (4,096 tokens), yielding ~15 windows — well below the ~577-window burst that exhausted the ANE IOSurface pool. Closes #89. - New `EmbedderOverflowPolicy` and `EmbedderError` enums in SwitchcraftCore - `maxInputTokens: Int { get }` added to the `Embedder` protocol - Overflow guard in both embedder `encode(_:)` implementations - Four new tests: property value, truncate policy, reject policy, 1,000-call stress test - All existing Embedder conformers (MockEmbedder, test-local structs) updated Co-Authored-By: Claude Sonnet 4.6 --- .../SwitchcraftCore/Embedding/Embedder.swift | 11 + .../Embedding/EmbedderError.swift | 13 ++ .../Embedding/EmbedderOverflowPolicy.swift | 25 +++ .../SwitchcraftCoreML/T5CoreMLEmbedder.swift | 74 ++++++- .../SwitchcraftMetal/T5MetalEmbedder.swift | 24 +- .../SwitchcraftTests/SearchTimeoutTests.swift | 1 + .../Support/MockEmbedder.swift | 1 + .../SwitchcraftStoreTests.swift | 1 + .../T5CoreMLEmbedderOverflowTests.swift | 207 ++++++++++++++++++ 9 files changed, 347 insertions(+), 10 deletions(-) create mode 100644 Sources/SwitchcraftCore/Embedding/EmbedderError.swift create mode 100644 Sources/SwitchcraftCore/Embedding/EmbedderOverflowPolicy.swift create mode 100644 Tests/SwitchcraftTests/T5CoreMLEmbedderOverflowTests.swift diff --git a/Sources/SwitchcraftCore/Embedding/Embedder.swift b/Sources/SwitchcraftCore/Embedding/Embedder.swift index 9bbf30b..536f902 100644 --- a/Sources/SwitchcraftCore/Embedding/Embedder.swift +++ b/Sources/SwitchcraftCore/Embedding/Embedder.swift @@ -27,7 +27,18 @@ public protocol Embedder: Sendable { /// Stable identifier for the model. Recorded on `ChunkRecord.model`. var modelIdentifier: String { get } + /// Maximum number of tokens the embedder will process in a single `encode` + /// call. Inputs that tokenise to more tokens than this limit are handled + /// according to the conformer's configured `EmbedderOverflowPolicy`. + /// + /// Conformers should expose this as a `nonisolated let` stored property so + /// callers can read the limit without entering an actor. + var maxInputTokens: Int { get } + /// Encode `text` into a flat row-major `n × dims` per-token embedding /// matrix. Returns an empty array for empty / whitespace-only text. + /// + /// - Throws: `EmbedderError.inputTooLarge(actual:max:)` when the token + /// count exceeds `maxInputTokens` and the overflow policy is `.reject`. func encode(_ text: String) async throws -> [Float] } diff --git a/Sources/SwitchcraftCore/Embedding/EmbedderError.swift b/Sources/SwitchcraftCore/Embedding/EmbedderError.swift new file mode 100644 index 0000000..588130c --- /dev/null +++ b/Sources/SwitchcraftCore/Embedding/EmbedderError.swift @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: Apache-2.0 +import Foundation + +/// Errors thrown by `Embedder` implementations. +public enum EmbedderError: Error, Sendable, Equatable { + /// The tokenised input length exceeded the embedder's `maxInputTokens` limit + /// and the configured overflow policy is `.reject`. + /// + /// - Parameters: + /// - actual: The number of tokens produced by the tokenizer for the input. + /// - max: The embedder's `maxInputTokens` limit. + case inputTooLarge(actual: Int, max: Int) +} diff --git a/Sources/SwitchcraftCore/Embedding/EmbedderOverflowPolicy.swift b/Sources/SwitchcraftCore/Embedding/EmbedderOverflowPolicy.swift new file mode 100644 index 0000000..a8f0a06 --- /dev/null +++ b/Sources/SwitchcraftCore/Embedding/EmbedderOverflowPolicy.swift @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +import Foundation + +/// Controls how an `Embedder` handles inputs whose token count exceeds `maxInputTokens`. +/// +/// The two policies match the primary consumer trade-offs for retrieval workloads: +/// +/// - `.truncate` (default): Silently clips the token sequence to the first +/// `maxInputTokens` tokens and encodes the prefix. This is the established +/// convention (Hugging Face `truncation=True, max_length=...`) and is appropriate +/// for search, classification, and bulk-index pipelines where prefix content is +/// informative and silent data loss is acceptable. +/// +/// - `.reject`: Throws `EmbedderError.inputTooLarge(actual:max:)` so the caller can +/// decide whether to skip, summarise, or split the input. Use this when silent +/// truncation would violate the application's correctness guarantees. +public enum EmbedderOverflowPolicy: Sendable, Equatable, Hashable { + /// Silently truncate the token sequence to `maxInputTokens` elements and encode + /// the prefix. No error is thrown; embeddings are returned for the truncated input. + case truncate + + /// Throw `EmbedderError.inputTooLarge(actual:max:)` without calling the underlying + /// model. The caller is responsible for splitting, summarising, or skipping the input. + case reject +} diff --git a/Sources/SwitchcraftCoreML/T5CoreMLEmbedder.swift b/Sources/SwitchcraftCoreML/T5CoreMLEmbedder.swift index 0b023cc..1a102df 100644 --- a/Sources/SwitchcraftCoreML/T5CoreMLEmbedder.swift +++ b/Sources/SwitchcraftCoreML/T5CoreMLEmbedder.swift @@ -83,6 +83,20 @@ public actor T5CoreMLEmbedder: Embedder { /// `windowSize`. `256` matches Witchcraft's sliding-window stride. public nonisolated let stride: Int + /// Maximum total token count across all sliding windows that the embedder + /// will accept in a single `encode` call. Inputs that tokenise to more + /// tokens than this value are handled according to `overflowPolicy`. + /// + /// Default `8 * windowSize` (4,096 for the standard 512-token window), + /// yielding ~15 windows at stride 256 — well below the ~577-window burst + /// that exhausted the ANE IOSurface pool. Must be ≥ `windowSize`. + public nonisolated let maxInputTokens: Int + + /// Controls behaviour when a tokenised input exceeds `maxInputTokens`. + /// `.truncate` (default) silently clips to the prefix; `.reject` throws + /// `EmbedderError.inputTooLarge(actual:max:)`. + public nonisolated let overflowPolicy: EmbedderOverflowPolicy + private let tokenizer: Tokenizer private var predictor: any MLPredictor /// Recreates the main predictor on demand; used by proactive reload to @@ -132,6 +146,12 @@ public actor T5CoreMLEmbedder: Embedder { /// accumulated ANE IOSurface resources. Default `500` — see the stored /// property doc-comment for tuning guidance. Existing callers that omit /// this parameter are unaffected. + /// - maxInputTokens: maximum total token count accepted per `encode` call. + /// Inputs exceeding this limit are handled by `overflowPolicy`. Default + /// `8 * windowSize`. Must be ≥ `windowSize`. + /// - overflowPolicy: `.truncate` (default) clips oversized inputs to the + /// first `maxInputTokens` tokens; `.reject` throws + /// `EmbedderError.inputTooLarge(actual:max:)`. /// - Throws: any error from `MLModel.compileModel(at:)` or `MLModel(contentsOf:)`. public init( modelURL: URL, @@ -143,13 +163,18 @@ public actor T5CoreMLEmbedder: Embedder { stride: Int = 256, minNorm: Float = 1.0, failureLogURL: URL? = nil, - reloadInterval: Int = 500 + reloadInterval: Int = 500, + maxInputTokens: Int? = nil, + overflowPolicy: EmbedderOverflowPolicy = .truncate ) async throws { precondition(dims > 0 && dims % 2 == 0, "dims must be positive and even (Q4 codec packs two nibbles per byte)") precondition(windowSize > 0) precondition(stride > 0 && stride <= windowSize) precondition(reloadInterval > 0, "reloadInterval must be positive (used as modulo divisor)") + let resolvedMaxInputTokens = maxInputTokens ?? 8 * windowSize + precondition(resolvedMaxInputTokens >= windowSize, + "maxInputTokens must be >= windowSize (got \(resolvedMaxInputTokens) < \(windowSize))") let configuration = MLModelConfiguration() configuration.computeUnits = computeUnits @@ -194,6 +219,8 @@ public actor T5CoreMLEmbedder: Embedder { self.stride = stride self.failureLogURL = failureLogURL self.reloadInterval = reloadInterval + self.maxInputTokens = resolvedMaxInputTokens + self.overflowPolicy = overflowPolicy } /// Convenience init that resolves the model URL from a `Bundle`. @@ -209,7 +236,9 @@ public actor T5CoreMLEmbedder: Embedder { stride: Int = 256, minNorm: Float = 1.0, failureLogURL: URL? = nil, - reloadInterval: Int = 500 + reloadInterval: Int = 500, + maxInputTokens: Int? = nil, + overflowPolicy: EmbedderOverflowPolicy = .truncate ) async throws { guard let url = bundle.url( forResource: resourceName, @@ -231,7 +260,9 @@ public actor T5CoreMLEmbedder: Embedder { stride: stride, minNorm: minNorm, failureLogURL: failureLogURL, - reloadInterval: reloadInterval + reloadInterval: reloadInterval, + maxInputTokens: maxInputTokens, + overflowPolicy: overflowPolicy ) } @@ -246,12 +277,17 @@ public actor T5CoreMLEmbedder: Embedder { stride: Int = 256, minNorm: Float = 1.0, modelIdentifier: String = "stub@v0", - failureLogURL: URL? = nil + failureLogURL: URL? = nil, + maxInputTokens: Int? = nil, + overflowPolicy: EmbedderOverflowPolicy = .truncate ) { precondition(dims > 0 && dims % 2 == 0, "dims must be positive and even") precondition(windowSize > 0) precondition(stride > 0 && stride <= windowSize) + let resolvedMaxInputTokens = maxInputTokens ?? 8 * windowSize + precondition(resolvedMaxInputTokens >= windowSize, + "maxInputTokens must be >= windowSize") let capturedPredictor = predictor self.predictorFactory = { capturedPredictor } self.cpuPredictorFactory = nil @@ -264,6 +300,8 @@ public actor T5CoreMLEmbedder: Embedder { self.modelIdentifier = modelIdentifier self.failureLogURL = failureLogURL self.reloadInterval = 500 + self.maxInputTokens = resolvedMaxInputTokens + self.overflowPolicy = overflowPolicy } /// Test-only init: inject a factory for predictor lifecycle testing. @@ -283,13 +321,18 @@ public actor T5CoreMLEmbedder: Embedder { minNorm: Float = 1.0, modelIdentifier: String = "stub@v0", failureLogURL: URL? = nil, - reloadInterval: Int = 500 + reloadInterval: Int = 500, + maxInputTokens: Int? = nil, + overflowPolicy: EmbedderOverflowPolicy = .truncate ) throws { precondition(dims > 0 && dims % 2 == 0, "dims must be positive and even") precondition(windowSize > 0) precondition(stride > 0 && stride <= windowSize) precondition(reloadInterval > 0, "reloadInterval must be positive (used as modulo divisor)") + let resolvedMaxInputTokens = maxInputTokens ?? 8 * windowSize + precondition(resolvedMaxInputTokens >= windowSize, + "maxInputTokens must be >= windowSize") self.predictorFactory = predictorFactory self.cpuPredictorFactory = cpuPredictorFactory self.predictor = try predictorFactory() @@ -301,6 +344,8 @@ public actor T5CoreMLEmbedder: Embedder { self.modelIdentifier = modelIdentifier self.failureLogURL = failureLogURL self.reloadInterval = reloadInterval + self.maxInputTokens = resolvedMaxInputTokens + self.overflowPolicy = overflowPolicy } // MARK: - Embedder @@ -313,8 +358,10 @@ public actor T5CoreMLEmbedder: Embedder { /// failures are silently retried on CPU (see class doc-comment); callers /// only receive an error if the retry also fails. /// - /// - Throws: `T5CoreMLEmbedderError.missingOutput` if the CoreML - /// model does not produce the expected feature dictionary; + /// - Throws: `EmbedderError.inputTooLarge(actual:max:)` when the token + /// count exceeds `maxInputTokens` and `overflowPolicy` is `.reject`; + /// `T5CoreMLEmbedderError.missingOutput` if the CoreML model does not + /// produce the expected feature dictionary; /// `CoreMLNativeError.nativeException` if CoreML raises an internal /// ObjC exception that the embedder cannot recover from; any /// tokenizer-originated error. @@ -339,9 +386,20 @@ public actor T5CoreMLEmbedder: Embedder { let trimmed = text.trimmingCharacters(in: .whitespacesAndNewlines) if trimmed.isEmpty { return [] } - let tokens = try tokenizer.encode(text, addSpecialTokens: true) + var tokens = try tokenizer.encode(text, addSpecialTokens: true) if tokens.isEmpty { return [] } + // Overflow guard: prevent oversized inputs from generating hundreds of + // sliding windows and exhausting the ANE IOSurface buffer pool (ADR 022). + if tokens.count > maxInputTokens { + switch overflowPolicy { + case .truncate: + tokens = Array(tokens.prefix(maxInputTokens)) + case .reject: + throw EmbedderError.inputTooLarge(actual: tokens.count, max: maxInputTokens) + } + } + // Proactive model reload: recreate the predictor every reloadInterval // encodes to flush accumulated ANE IOSurface resources. // Counter increments only for real inference calls (whitespace-only inputs diff --git a/Sources/SwitchcraftMetal/T5MetalEmbedder.swift b/Sources/SwitchcraftMetal/T5MetalEmbedder.swift index 94afda9..6922025 100644 --- a/Sources/SwitchcraftMetal/T5MetalEmbedder.swift +++ b/Sources/SwitchcraftMetal/T5MetalEmbedder.swift @@ -65,6 +65,8 @@ public actor T5MetalEmbedder: Embedder { public nonisolated let minNorm: Float public nonisolated let windowSize: Int public nonisolated let stride: Int + public nonisolated let maxInputTokens: Int + public nonisolated let overflowPolicy: EmbedderOverflowPolicy // MARK: - Architecture constants // @@ -149,12 +151,17 @@ public actor T5MetalEmbedder: Embedder { windowSize windowSizeParam: Int = 512, stride strideParam: Int = 256, minNorm minNormParam: Float = 1.0, - modelIdentifier modelIdentifierParam: String = "google/xtr-base-en@v1+gguf" + modelIdentifier modelIdentifierParam: String = "google/xtr-base-en@v1+gguf", + maxInputTokens maxInputTokensParam: Int? = nil, + overflowPolicy overflowPolicyParam: EmbedderOverflowPolicy = .truncate ) async throws { precondition(dimsParam > 0 && dimsParam % 2 == 0, "dims must be positive and even (Q4 codec packs two nibbles per byte)") precondition(windowSizeParam > 0) precondition(strideParam > 0 && strideParam <= windowSizeParam) + let resolvedMaxInputTokens = maxInputTokensParam ?? 8 * windowSizeParam + precondition(resolvedMaxInputTokens >= windowSizeParam, + "maxInputTokens must be >= windowSize") guard let context = MetalContext.shared else { throw T5MetalEmbedderError.metalUnavailable @@ -414,6 +421,8 @@ public actor T5MetalEmbedder: Embedder { self.minNorm = minNormParam self.windowSize = windowSizeParam self.stride = strideParam + self.maxInputTokens = resolvedMaxInputTokens + self.overflowPolicy = overflowPolicyParam self.tokenizer = tokenizer self.context = context self.layers = layerWeights @@ -456,9 +465,20 @@ public actor T5MetalEmbedder: Embedder { let trimmed = text.trimmingCharacters(in: .whitespacesAndNewlines) if trimmed.isEmpty { return [] } - let tokens = try tokenizer.encode(text, addSpecialTokens: true) + var tokens = try tokenizer.encode(text, addSpecialTokens: true) if tokens.isEmpty { return [] } + // Overflow guard: prevent oversized inputs from generating hundreds of + // Metal command buffers and exhausting device memory (ADR 022). + if tokens.count > maxInputTokens { + switch overflowPolicy { + case .truncate: + tokens = Array(tokens.prefix(maxInputTokens)) + case .reject: + throw EmbedderError.inputTooLarge(actual: tokens.count, max: maxInputTokens) + } + } + let starts = SlidingWindow.plan( tokenCount: tokens.count, windowSize: windowSize, diff --git a/Tests/SwitchcraftTests/SearchTimeoutTests.swift b/Tests/SwitchcraftTests/SearchTimeoutTests.swift index fa57a0a..f82763a 100644 --- a/Tests/SwitchcraftTests/SearchTimeoutTests.swift +++ b/Tests/SwitchcraftTests/SearchTimeoutTests.swift @@ -55,6 +55,7 @@ struct SearchTimeoutTests { var dims: Int { inner.dims } var modelIdentifier: String { inner.modelIdentifier } + var maxInputTokens: Int { inner.maxInputTokens } func encode(_ text: String) async throws -> [Float] { // Task.sleep respects task cancellation: it throws diff --git a/Tests/SwitchcraftTests/Support/MockEmbedder.swift b/Tests/SwitchcraftTests/Support/MockEmbedder.swift index fd15c58..f7f4a6d 100644 --- a/Tests/SwitchcraftTests/Support/MockEmbedder.swift +++ b/Tests/SwitchcraftTests/Support/MockEmbedder.swift @@ -15,6 +15,7 @@ import SwitchcraftCore struct MockEmbedder: Embedder, Sendable { let dims: Int let modelIdentifier: String + let maxInputTokens: Int = Int.max init(dims: Int = 128, modelIdentifier: String? = nil) { precondition(dims > 0 && dims % 2 == 0, diff --git a/Tests/SwitchcraftTests/SwitchcraftStoreTests.swift b/Tests/SwitchcraftTests/SwitchcraftStoreTests.swift index 4f24545..aba9e62 100644 --- a/Tests/SwitchcraftTests/SwitchcraftStoreTests.swift +++ b/Tests/SwitchcraftTests/SwitchcraftStoreTests.swift @@ -339,6 +339,7 @@ struct SwitchcraftStoreTests { struct OddDimsEmbedder: Embedder { let dims = 33 let modelIdentifier = "odd" + let maxInputTokens: Int = Int.max func encode(_ text: String) async throws -> [Float] { [] } } diff --git a/Tests/SwitchcraftTests/T5CoreMLEmbedderOverflowTests.swift b/Tests/SwitchcraftTests/T5CoreMLEmbedderOverflowTests.swift new file mode 100644 index 0000000..0af1a1d --- /dev/null +++ b/Tests/SwitchcraftTests/T5CoreMLEmbedderOverflowTests.swift @@ -0,0 +1,207 @@ +// SPDX-License-Identifier: Apache-2.0 +import Foundation +import Testing +import SwitchcraftCore +@testable import SwitchcraftCoreML + +#if canImport(CoreML) +import CoreML + +/// Tests for `T5CoreMLEmbedder`'s overflow guard (`maxInputTokens` / `overflowPolicy`). +/// +/// No CoreML model asset is required — `CountingStubPredictor` is injected +/// via the factory-based internal init. All tests use a small `windowSize` and +/// `dims` to keep per-iteration allocations tiny. +@Suite("T5CoreMLEmbedder overflow guard") +struct T5CoreMLEmbedderOverflowTests { + + // MARK: - Helpers + + private static func makeTokenizer() throws -> Tokenizer { + let url = try #require( + Bundle.module.url( + forResource: "xtr-base-en.tokenizer", + withExtension: "json", + subdirectory: "Fixtures" + ), + "tokenizer fixture missing from test bundle" + ) + return try Tokenizer(contentsOf: url.path) + } + + /// Returns a string that reliably tokenises to more than `minTokens` tokens. + /// Repeats "hello " enough times that the BPE tokenizer (plus the T5 + /// sentinel) always produces more than `minTokens` tokens. + private static func oversizedInput(minTokens: Int) -> String { + // "hello" is a single BPE token in xtr-base-en; " " is a word-start + // prefix on the following token. Repeating "hello " minTokens+5 times + // reliably produces more than minTokens tokens after tokenisation. + return String(repeating: "hello ", count: minTokens + 5) + } + + // MARK: - R7d: maxInputTokens property + + @Test("maxInputTokens property returns the configured value") + func testMaxInputTokensPropertyReturnsConfiguredValue() throws { + let tokenizer = try Self.makeTokenizer() + let dims = 16 + let windowSize = 8 + let customLimit = 24 // >= windowSize + + let embedder = try T5CoreMLEmbedder( + predictorFactory: { CountingStubPredictor(dims: dims) }, + tokenizer: tokenizer, + dims: dims, + windowSize: windowSize, + stride: 4, + minNorm: 1.0, + maxInputTokens: customLimit + ) + + #expect(embedder.maxInputTokens == customLimit) + } + + // MARK: - R7a: truncate policy + + @Test("Truncate policy encodes oversized input without error and returns non-empty embeddings") + func testTruncatePolicyEncodesOversizedInputWithoutError() async throws { + let tokenizer = try Self.makeTokenizer() + let dims = 16 + let windowSize = 8 + let maxTokens = 16 // >= windowSize; guarantees at least 1 full window + + let embedder = try T5CoreMLEmbedder( + predictorFactory: { CountingStubPredictor(dims: dims) }, + tokenizer: tokenizer, + dims: dims, + windowSize: windowSize, + stride: 4, + minNorm: 1.0, + maxInputTokens: maxTokens, + overflowPolicy: .truncate + ) + + // Build an input that definitely exceeds maxTokens. + let input = Self.oversizedInput(minTokens: maxTokens) + let result = try await embedder.encode(input) + + // Truncation succeeds: the result must be non-empty and at most + // maxTokens rows wide (each row is `dims` floats). + #expect(!result.isEmpty, "truncated encode must return non-empty embeddings") + #expect(result.count <= maxTokens * dims, + "result has \(result.count) floats; expected ≤ \(maxTokens * dims) for \(maxTokens) tokens × \(dims) dims") + } + + // MARK: - R7b: reject policy + + @Test("Reject policy throws EmbedderError.inputTooLarge with correct actual and max values") + func testRejectPolicyThrowsInputTooLargeError() async throws { + let tokenizer = try Self.makeTokenizer() + let dims = 16 + let windowSize = 8 + let maxTokens = 16 + + let embedder = try T5CoreMLEmbedder( + predictorFactory: { CountingStubPredictor(dims: dims) }, + tokenizer: tokenizer, + dims: dims, + windowSize: windowSize, + stride: 4, + minNorm: 1.0, + maxInputTokens: maxTokens, + overflowPolicy: .reject + ) + + // Count the tokens the tokenizer actually produces so we can verify + // the `actual` field on the thrown error. + let input = Self.oversizedInput(minTokens: maxTokens) + let actualTokenCount = try tokenizer.encode(input, addSpecialTokens: true).count + try #require(actualTokenCount > maxTokens, + "test precondition: input must tokenise to > \(maxTokens) tokens; got \(actualTokenCount)") + + await #expect(throws: EmbedderError.inputTooLarge(actual: actualTokenCount, max: maxTokens)) { + _ = try await embedder.encode(input) + } + } + + // MARK: - R5, R7c: stress test + + @Test("1,000 encode calls with oversized inputs interleaved complete without ANE pool degradation (truncate policy)") + func testOverflowStress1000CallsTruncatePolicy() async throws { + let tokenizer = try Self.makeTokenizer() + let dims = 16 + let windowSize = 8 + let maxTokens = 16 + + let embedder = try T5CoreMLEmbedder( + predictorFactory: { CountingStubPredictor(dims: dims) }, + tokenizer: tokenizer, + dims: dims, + windowSize: windowSize, + stride: 4, + minNorm: 1.0, + maxInputTokens: maxTokens, + overflowPolicy: .truncate + ) + + let normalInput = "semantic search" + let oversizedInput = Self.oversizedInput(minTokens: maxTokens) + + for i in 0..<1_000 { + let text = (i == 250 || i == 750) ? oversizedInput : normalInput + // All calls must succeed (oversized inputs are silently truncated). + let result = try await embedder.encode(text) + #expect(!result.isEmpty, "encode returned empty at index \(i)") + } + } + + @Test("1,000 encode calls: reject policy throws on oversized inputs but normal-sized calls always succeed") + func testOverflowStress1000CallsRejectPolicy() async throws { + let tokenizer = try Self.makeTokenizer() + let dims = 16 + let windowSize = 8 + let maxTokens = 16 + + let embedder = try T5CoreMLEmbedder( + predictorFactory: { CountingStubPredictor(dims: dims) }, + tokenizer: tokenizer, + dims: dims, + windowSize: windowSize, + stride: 4, + minNorm: 1.0, + maxInputTokens: maxTokens, + overflowPolicy: .reject + ) + + let normalInput = "semantic search" + let oversizedInput = Self.oversizedInput(minTokens: maxTokens) + + var oversizedErrors = 0 + for i in 0..<1_000 { + if i == 250 || i == 750 { + // Must throw inputTooLarge — and must NOT leave the embedder in a + // broken state (subsequent normal calls must still work). + do { + _ = try await embedder.encode(oversizedInput) + Issue.record("Expected EmbedderError.inputTooLarge at index \(i) but no error was thrown") + } catch let err as EmbedderError { + if case .inputTooLarge = err { + oversizedErrors += 1 + } else { + Issue.record("Unexpected EmbedderError at index \(i): \(err)") + } + } catch { + Issue.record("Unexpected error type at index \(i): \(error)") + } + } else { + // Normal-sized call must always succeed. + let result = try await embedder.encode(normalInput) + #expect(!result.isEmpty, "normal encode returned empty at index \(i)") + } + } + + #expect(oversizedErrors == 2, "expected 2 inputTooLarge errors (indices 250 and 750), got \(oversizedErrors)") + } +} + +#endif From c3857a01fd466a07ee557d317a8848f42caf3a63 Mon Sep 17 00:00:00 2001 From: Shannon Holland Date: Wed, 6 May 2026 08:22:42 -0700 Subject: [PATCH 2/5] docs: add ADR 022 (embedder overflow guard) and update Plan.md Documents the design decisions from issue #89: why maxInputTokens belongs on the Embedder protocol, the 8 * windowSize default rationale, why EmbedderOverflowPolicy/EmbedderError live in SwitchcraftCore, and why T5MetalEmbedder gets the guard without stub-based tests. Co-Authored-By: Claude Sonnet 4.6 --- adrs/022-embedder-overflow-guard.md | 94 +++++++++++++++++++++++++++++ docs/Plan.md | 1 + 2 files changed, 95 insertions(+) create mode 100644 adrs/022-embedder-overflow-guard.md diff --git a/adrs/022-embedder-overflow-guard.md b/adrs/022-embedder-overflow-guard.md new file mode 100644 index 0000000..3def3a2 --- /dev/null +++ b/adrs/022-embedder-overflow-guard.md @@ -0,0 +1,94 @@ +# ADR 022: Embedder overflow guard — `maxInputTokens` and `EmbedderOverflowPolicy` + +**Status:** Accepted +**Date:** 2026-05-06 +**Issue:** [#89](https://github.com/totalslacker/switchcraft/issues/89) + +## Context + +`T5CoreMLEmbedder` uses a sliding-window strategy (ADR 011) to handle inputs longer than +the CoreML model's fixed 512-token input. For a 593,285-character page that tokenises to +~148,000 tokens, `SlidingWindow.plan` generates approximately 577 windows. 577 consecutive +CoreML predictions within a single `encode()` call exhausted the ANE IOSurface buffer +pool, leaving the embedder in a state where every subsequent inference — including +small inputs — failed with `MLE5OutputPortBinder bindAndReturnError`. Process restart +was required to recover. + +ADR 021 adds three reactive/proactive defence layers (autoreleasepool drainage, proactive +model reload, CPU fallback). This ADR adds the structural prevention layer: refuse to +generate a window burst this large in the first place. + +## Decision + +### 1. `EmbedderOverflowPolicy` and `EmbedderError` live in `SwitchcraftCore` + +Both types are pure Swift with no CoreML or Metal dependency. Placing them in +`SwitchcraftCore` (not `SwitchcraftCoreML`) makes them: + +- Available to any `Embedder` conformer, including `T5MetalEmbedder` and future backends. +- Reachable by callers who hold an `any Embedder` reference without downcasting. +- Not gated behind `#if canImport(CoreML)`. + +### 2. `maxInputTokens: Int { get }` is added to the `Embedder` protocol + +Adding the limit to the protocol rather than only to concrete types: + +- Makes the safety contract queryable on any `Embedder`-typed reference. +- Forces new conformers to declare an explicit limit — an unknown embedder with no + stated limit is unsafe for bulk-index consumers. +- Follows the existing pattern where `dims` and `modelIdentifier` are protocol-level. + +Conformers are encouraged (but not required by the compiler) to implement this as a +`nonisolated let` stored property so callers can read it without entering an actor. + +### 3. Default `maxInputTokens = 8 * windowSize` + +For the standard 512-token window, this is 4,096 total tokens → ~15 windows at stride 256. +The multiplier form scales naturally when `windowSize` is customised at init. 4,096 covers +most real-world documents (~3,000 words) while keeping the window count well below the +threshold observed to exhaust the ANE pool (~577). + +Callers running in environments with larger IOSurface budgets (or no ANE) may pass a larger +value at init. A `precondition(maxInputTokens >= windowSize)` guards against setting the +limit below one window, which would make every non-trivial encode either truncate to +padding or always throw. + +### 4. Two policies: `.truncate` (default) and `.reject` + +`.truncate` is the default for three reasons: +- It matches the established HuggingFace convention (`truncation=True, max_length=...`). +- The primary use case (retrieval, classification, bulk indexing) benefits from prefix + embeddings when the full document is too long. +- It preserves backward compatibility — existing callers that do not pass `overflowPolicy` + get silent truncation rather than a new throw path. + +`.reject` is provided for callers that cannot accept silent data loss (e.g., applications +that want to split long documents themselves and know exactly which text was embedded). + +### 5. `T5MetalEmbedder` gains the same guard + +`T5MetalEmbedder` uses `SlidingWindow` identically to `T5CoreMLEmbedder`. While it lacks +the IOSurface failure mode, a 577-Metal-command-buffer burst from one `encode` call is both +slow and memory-intensive. With `maxInputTokens` now on the protocol, adding the guard to +the Metal embedder is both consistent and low-risk. No stub-based tests are added for the +Metal path because `T5MetalEmbedder.init` requires a real Metal device and GGUF asset; +the CoreML stub tests provide adequate algorithmic coverage of the overflow guard logic. + +## Guard placement + +The overflow guard is inserted in `encode(_:)` after `tokenizer.encode(text, addSpecialTokens: true)` +and before `SlidingWindow.plan`. This is the earliest point where the authoritative token +count is known. It fires inside the re-entrancy guard's `inFlight` window, so the `defer` +block that releases the next waiter still executes correctly when `.reject` throws. + +## Consequences + +- **Breaking change:** `Embedder` protocol gains a required property `maxInputTokens: Int { get }`. + All existing conformers must implement it. Test mocks and local test stubs return `Int.max`. +- `T5CoreMLEmbedder` and `T5MetalEmbedder` gain two new init parameters + (`maxInputTokens: Int? = nil`, `overflowPolicy: EmbedderOverflowPolicy = .truncate`) + with defaults that maintain backward compatibility. +- `EmbedderError` is a new public enum in `SwitchcraftCore`. It will accumulate future + embedder-level errors (not model-specific errors, which remain on `T5CoreMLEmbedderError`). +- Bulk-index consumers indexing very large documents will silently receive prefix embeddings + unless they explicitly configure `.reject`. This is documented on the `encode` method. diff --git a/docs/Plan.md b/docs/Plan.md index f76afae..4b38e40 100644 --- a/docs/Plan.md +++ b/docs/Plan.md @@ -318,6 +318,7 @@ Track progress by checking off items as they land. Effort estimates and notes fo - [x] **T5 encoder → CoreML** (1-2 weeks) — Model conversion + Swift wrapper - [x] **T5CoreMLEmbedder crash-safety** (#78) — `MLExceptionCatcher` ObjC `@try/@catch` bridge (separate `SwitchcraftCoreMLObjC` clang target per ADR 018); `catchingNSException` Swift facade converts CoreML internal `NSException`s to `CoreMLNativeError.nativeException` typed errors; `MLPredictor` internal protocol DI seam; re-entrancy guard (`inFlight` + `waiters` continuations); optional `failureLogURL` for JSONL crash telemetry + `os.Logger` logging - [x] **T5CoreMLEmbedder ANE IOSurface fix** (#87) — `autoreleasepool` per window, proactive model reload every `reloadInterval` encodes (default 500, tunable), reactive CPU-fallback with JSONL recovery telemetry (`"recovered_iosurface_exhaustion"`); stub stress test (5k iterations, always-on CI) + real-asset stress test (10k iterations, asset-gated); ADR 021 +- [x] **Embedder overflow guard** (#89) — `maxInputTokens: Int` added to `Embedder` protocol; `EmbedderOverflowPolicy` (`.truncate` / `.reject`) + `EmbedderError.inputTooLarge` in `SwitchcraftCore`; overflow guard in `T5CoreMLEmbedder.encode(_:)` and `T5MetalEmbedder.encode(_:)` between tokenization and `SlidingWindow.plan`; default `8 * windowSize` (4,096 tokens → ~15 windows); prevents ANE pool poisoning from oversized inputs; ADR 022 - [x] **K-means clustering** (1 week) — Standard algorithm, use Accelerate - [x] **4-bit residual codec** (1 week) — ~200 lines, bit-level packing; round-trip property tests - [x] **LSM-tree index structure** (1 week) — Cascading merge logic From e2b488e45d4e1da94623fb46df032e2a5b4f747d Mon Sep 17 00:00:00 2001 From: Shannon Holland Date: Wed, 6 May 2026 09:34:47 -0700 Subject: [PATCH 3/5] fix(ci): use @unchecked Sendable for retroactive MLModel conformance MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Swift 6 (Xcode 16.4 / CI) rejects bare retroactive Sendable conformances on types defined in other modules. MLModel is defined in CoreML, so `extension MLModel: MLPredictor` implicitly conferred Sendable — illegal in Swift 6. Adding `extension MLModel: @unchecked Sendable {}` satisfies the compiler. Safe: all MLModel access is gated through T5CoreMLEmbedder's actor isolation. Fixes CI regression introduced when the CI runner upgraded to Xcode 16.4. Co-Authored-By: Claude Sonnet 4.6 --- Sources/SwitchcraftCoreML/MLPredictor.swift | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Sources/SwitchcraftCoreML/MLPredictor.swift b/Sources/SwitchcraftCoreML/MLPredictor.swift index 977593b..fd2c787 100644 --- a/Sources/SwitchcraftCoreML/MLPredictor.swift +++ b/Sources/SwitchcraftCoreML/MLPredictor.swift @@ -16,6 +16,10 @@ internal protocol MLPredictor: Sendable { func predict(input: any MLFeatureProvider) throws -> any MLFeatureProvider } +// @unchecked: retroactive conformance (MLModel is from CoreML); safe because +// all access is gated through T5CoreMLEmbedder's actor isolation. +extension MLModel: @unchecked Sendable {} + extension MLModel: MLPredictor { internal func predict(input: any MLFeatureProvider) throws -> any MLFeatureProvider { try self.prediction(from: input) From a155c5da27e3dd3f74989abbea0e9aa0d562e027 Mon Sep 17 00:00:00 2001 From: Shannon Holland Date: Wed, 6 May 2026 10:51:02 -0700 Subject: [PATCH 4/5] fix(ci): resolve 'sending provider' Swift 6 data-race error MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Swift 6 (Xcode 16.4) treats closures passed to ObjC blocks (including autoreleasepool + MLExceptionCatcher.perform) as region-boundary crossings. Passing a non-Sendable `any MLFeatureProvider` across that boundary triggers "sending 'provider' risks causing data races" at T5CoreMLEmbedder.swift:506,526. Fix: add retroactive @unchecked Sendable to MLDictionaryFeatureProvider (safe: always accessed within T5CoreMLEmbedder's actor isolation), and narrow predictWindow's parameter type from `any MLFeatureProvider` to the concrete MLDictionaryFeatureProvider — the only type ever passed at the single call site. Co-Authored-By: Claude Sonnet 4.6 --- Sources/SwitchcraftCoreML/MLPredictor.swift | 5 +++-- Sources/SwitchcraftCoreML/T5CoreMLEmbedder.swift | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/Sources/SwitchcraftCoreML/MLPredictor.swift b/Sources/SwitchcraftCoreML/MLPredictor.swift index fd2c787..12cca85 100644 --- a/Sources/SwitchcraftCoreML/MLPredictor.swift +++ b/Sources/SwitchcraftCoreML/MLPredictor.swift @@ -16,9 +16,10 @@ internal protocol MLPredictor: Sendable { func predict(input: any MLFeatureProvider) throws -> any MLFeatureProvider } -// @unchecked: retroactive conformance (MLModel is from CoreML); safe because -// all access is gated through T5CoreMLEmbedder's actor isolation. +// @unchecked: retroactive conformances (types from CoreML); safe because all +// access is gated through T5CoreMLEmbedder's actor isolation. extension MLModel: @unchecked Sendable {} +extension MLDictionaryFeatureProvider: @unchecked Sendable {} extension MLModel: MLPredictor { internal func predict(input: any MLFeatureProvider) throws -> any MLFeatureProvider { diff --git a/Sources/SwitchcraftCoreML/T5CoreMLEmbedder.swift b/Sources/SwitchcraftCoreML/T5CoreMLEmbedder.swift index 1ca2cb0..351cdeb 100644 --- a/Sources/SwitchcraftCoreML/T5CoreMLEmbedder.swift +++ b/Sources/SwitchcraftCoreML/T5CoreMLEmbedder.swift @@ -496,7 +496,7 @@ public actor T5CoreMLEmbedder: Embedder { /// Run one window prediction with autoreleasepool drainage, reactive reload, /// ANE retry, and IOSurface CPU fallback. private func predictWindow( - provider: any MLFeatureProvider, + provider: MLDictionaryFeatureProvider, inputLength: Int, windowTokenCount: Int ) throws -> any MLFeatureProvider { From 61e23973750155af23723d94bd9eca50606a254a Mon Sep 17 00:00:00 2001 From: Shannon Holland Date: Wed, 6 May 2026 10:53:06 -0700 Subject: [PATCH 5/5] fix(ci): add @retroactive to silence retroactive Sendable conformance warnings Swift recommends @retroactive for conformances of external types to external protocols to avoid future conflicts if CoreML adds the conformance itself. Co-Authored-By: Claude Sonnet 4.6 --- Sources/SwitchcraftCoreML/MLPredictor.swift | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Sources/SwitchcraftCoreML/MLPredictor.swift b/Sources/SwitchcraftCoreML/MLPredictor.swift index 12cca85..a55b001 100644 --- a/Sources/SwitchcraftCoreML/MLPredictor.swift +++ b/Sources/SwitchcraftCoreML/MLPredictor.swift @@ -16,10 +16,10 @@ internal protocol MLPredictor: Sendable { func predict(input: any MLFeatureProvider) throws -> any MLFeatureProvider } -// @unchecked: retroactive conformances (types from CoreML); safe because all -// access is gated through T5CoreMLEmbedder's actor isolation. -extension MLModel: @unchecked Sendable {} -extension MLDictionaryFeatureProvider: @unchecked Sendable {} +// @unchecked @retroactive: retroactive conformances on CoreML types; safe +// because all access is gated through T5CoreMLEmbedder's actor isolation. +extension MLModel: @unchecked @retroactive Sendable {} +extension MLDictionaryFeatureProvider: @unchecked @retroactive Sendable {} extension MLModel: MLPredictor { internal func predict(input: any MLFeatureProvider) throws -> any MLFeatureProvider {