Skip to content

Commit b0d3374

Browse files
author
Engin Mahmut
committed
Add enums for encode objectives/metrics/segments
Introduce typed enums for encoding metadata, with wire-safe coercion: EncodeObjective, EncodeMetric, EncodingSegmentKind, EncodingBoundType, and WorkloadSuitability. Update the core and vector models to use these enums (with coercion helpers), change codec implementations and helper modules to use the enums instead of raw strings, and ensure segment_kind/guarantee fields serialize as enum values. Export the new symbols from the package __init__ and update examples and tests to use the enums. Also bump the package version to 0.2.0, fix a few robustness issues (mlx import lint fix, pyproject cleanup), and harden tests (git executable check; exclude venv paths when scanning for bytecode).
1 parent 4c9d10f commit b0d3374

51 files changed

Lines changed: 454 additions & 233 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

README.md

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,18 @@ Stable today:
9898
- `CompressionFootprint`
9999
- `CompressionGuarantee`
100100
- `ValidationEvidence`
101+
- `EncodingBoundType`
102+
- `WorkloadSuitability`
101103
- `VectorEncodeRequest`
102104
- `VectorEncodingSegment`
103105
- `VectorEncoding`
104106
- `VectorDecodeRequest`
105107
- `VectorDecodeResult`
106108
- `VectorCodec`
107109
- `PassthroughVectorCodec`
110+
- `EncodeObjective`
111+
- `EncodeMetric`
112+
- `EncodingSegmentKind`
108113

109114
Available today, but intentionally outside the stable root surface:
110115
- `semafold.turboquant`
@@ -139,14 +144,15 @@ Run the exact file here: [examples/wire_roundtrip.py](examples/wire_roundtrip.py
139144
```python
140145
import numpy as np
141146

147+
from semafold import EncodeObjective
142148
from semafold import PassthroughVectorCodec
143149
from semafold import VectorDecodeRequest
144150
from semafold import VectorEncodeRequest
145151

146152
codec = PassthroughVectorCodec()
147153
request = VectorEncodeRequest(
148154
data=np.linspace(-1.0, 1.0, 1024, dtype=np.float32),
149-
objective="reconstruction",
155+
objective=EncodeObjective.RECONSTRUCTION,
150156
)
151157

152158
encoding = codec.encode(request)
@@ -162,6 +168,8 @@ Run the exact file here: [examples/turboquant_embedding.py](examples/turboquant_
162168
```python
163169
import numpy as np
164170

171+
from semafold import EncodeMetric
172+
from semafold import EncodeObjective
165173
from semafold import VectorDecodeRequest
166174
from semafold import VectorEncodeRequest
167175
from semafold.turboquant import TurboQuantMSEConfig
@@ -175,8 +183,8 @@ codec = TurboQuantMSEVectorCodec(
175183
encoding = codec.encode(
176184
VectorEncodeRequest(
177185
data=rows,
178-
objective="reconstruction",
179-
metric="mse",
186+
objective=EncodeObjective.RECONSTRUCTION,
187+
metric=EncodeMetric.MSE,
180188
role="embedding",
181189
seed=11,
182190
)

STABILITY.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,18 @@ Stable surface:
1212
- `CompressionFootprint`
1313
- `CompressionGuarantee`
1414
- `ValidationEvidence`
15+
- `EncodingBoundType`
16+
- `WorkloadSuitability`
1517
- `VectorEncodeRequest`
1618
- `VectorEncodingSegment`
1719
- `VectorEncoding`
1820
- `VectorDecodeRequest`
1921
- `VectorDecodeResult`
2022
- `VectorCodec`
2123
- `PassthroughVectorCodec`
24+
- `EncodeObjective`
25+
- `EncodeMetric`
26+
- `EncodingSegmentKind`
2227

2328
## Intentionally Not Stable
2429

benchmarks/turboquant_paper_validation.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from semafold import __version__
1212
from semafold import VectorDecodeRequest, VectorEncodeRequest
13+
from semafold.vector.models import EncodeObjective, EncodeMetric
1314
from semafold.turboquant import (
1415
TurboQuantMSEConfig,
1516
TurboQuantMSEVectorCodec,
@@ -44,7 +45,7 @@ def _mse_record(*, rows: np.ndarray, bits_per_scalar: int, rotation_seed: int) -
4445
default_rotation_seed=rotation_seed,
4546
)
4647
)
47-
request = VectorEncodeRequest(data=rows, objective="reconstruction", metric="mse")
48+
request = VectorEncodeRequest(data=rows, objective=EncodeObjective.RECONSTRUCTION, metric=EncodeMetric.MSE)
4849

4950
encode_start = time.perf_counter()
5051
encoding = codec.encode(request)
@@ -87,8 +88,8 @@ def _prod_record(
8788
)
8889
request = VectorEncodeRequest(
8990
data=rows,
90-
objective="inner_product_estimation",
91-
metric="dot_product_error",
91+
objective=EncodeObjective.INNER_PRODUCT_ESTIMATION,
92+
metric=EncodeMetric.DOT_PRODUCT_ERROR,
9293
)
9394

9495
encode_start = time.perf_counter()

benchmarks/turboquant_synthetic_kv_benchmark.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from semafold import __version__
1212
from semafold import VectorDecodeRequest, VectorEncodeRequest
13+
from semafold.vector.models import EncodeObjective, EncodeMetric
1314
from semafold.turboquant import (
1415
TurboQuantMSEConfig,
1516
TurboQuantMSEVectorCodec,
@@ -101,8 +102,8 @@ def run_synthetic_kv_benchmark(
101102
)
102103
key_request = VectorEncodeRequest(
103104
data=key_rows,
104-
objective="inner_product_estimation",
105-
metric="dot_product_error",
105+
objective=EncodeObjective.INNER_PRODUCT_ESTIMATION,
106+
metric=EncodeMetric.DOT_PRODUCT_ERROR,
106107
role="key_cache",
107108
)
108109
key_encode_start = time.perf_counter()
@@ -124,8 +125,8 @@ def run_synthetic_kv_benchmark(
124125
)
125126
value_request = VectorEncodeRequest(
126127
data=value_rows,
127-
objective="reconstruction",
128-
metric="mse",
128+
objective=EncodeObjective.RECONSTRUCTION,
129+
metric=EncodeMetric.MSE,
129130
role="value_cache",
130131
)
131132
value_encode_start = time.perf_counter()

examples/turboquant_embedding.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from semafold.turboquant import TurboQuantMSEConfig
1010
from semafold.turboquant import TurboQuantMSEVectorCodec
1111
from semafold.turboquant.backends import get_backend
12+
from semafold.vector.models import EncodeObjective, EncodeMetric
1213

1314

1415
def _row_cosine_similarity(lhs: np.ndarray, rhs: np.ndarray) -> float:
@@ -54,8 +55,8 @@ def main() -> None:
5455
encoding = codec.encode(
5556
VectorEncodeRequest(
5657
data=rows,
57-
objective="reconstruction",
58-
metric="mse",
58+
objective=EncodeObjective.RECONSTRUCTION,
59+
metric=EncodeMetric.MSE,
5960
role="embedding",
6061
seed=11,
6162
)

examples/wire_roundtrip.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from semafold import VectorDecodeRequest
99
from semafold import VectorEncodeRequest
1010
from semafold import VectorEncoding
11+
from semafold.vector.models import EncodeObjective
1112

1213

1314
def _format_bytes(value: int) -> str:
@@ -33,7 +34,7 @@ def main() -> None:
3334
codec = PassthroughVectorCodec()
3435
request = VectorEncodeRequest(
3536
data=np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32),
36-
objective="reconstruction",
37+
objective=EncodeObjective.RECONSTRUCTION,
3738
role="embedding",
3839
profile_id="examples.wire_roundtrip",
3940
)

pyproject.toml

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,4 @@ include = ["src", "tests", "examples", "benchmarks"]
4242
typeCheckingMode = "basic"
4343
venvPath = "."
4444
venv = ".venv"
45-
# Optional backends (torch, mlx) are not installed in the base dev environment.
46-
# reportMissingImports is silenced so pyright does not error on those packages;
47-
# structural type errors in those files are still caught.
48-
reportMissingImports = "none"
45+

src/semafold/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
CompressionGuarantee,
99
ValidationEvidence,
1010
)
11+
from semafold.core.models import EncodingBoundType, WorkloadSuitability
1112
from semafold.vector import (
1213
PassthroughVectorCodec,
1314
VectorCodec,
@@ -17,13 +18,18 @@
1718
VectorEncoding,
1819
VectorEncodingSegment,
1920
)
21+
from semafold.vector.models import EncodeMetric, EncodeObjective, EncodingSegmentKind
2022

2123
__all__ = [
2224
"__version__",
2325
"CompressionBudget",
2426
"CompressionEstimate",
2527
"CompressionFootprint",
2628
"CompressionGuarantee",
29+
"EncodeMetric",
30+
"EncodeObjective",
31+
"EncodingBoundType",
32+
"EncodingSegmentKind",
2733
"PassthroughVectorCodec",
2834
"ValidationEvidence",
2935
"VectorCodec",
@@ -32,4 +38,5 @@
3238
"VectorEncodeRequest",
3339
"VectorEncoding",
3440
"VectorEncodingSegment",
41+
"WorkloadSuitability",
3542
]

src/semafold/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22

33
__all__ = ["__version__"]
44

5-
__version__ = "0.1.0"
5+
__version__ = "0.2.0"

src/semafold/core/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44

55
from .accounting import CompressionEstimate, CompressionFootprint
66
from .evidence import ValidationEvidence
7-
from .models import CompressionBudget, CompressionGuarantee
7+
from .models import CompressionBudget, CompressionGuarantee, EncodingBoundType, WorkloadSuitability
88

99
__all__ = [
1010
"CompressionBudget",
1111
"CompressionEstimate",
1212
"CompressionFootprint",
1313
"CompressionGuarantee",
14+
"EncodingBoundType",
1415
"ValidationEvidence",
16+
"WorkloadSuitability",
1517
]

0 commit comments

Comments
 (0)