-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathwire_roundtrip.py
More file actions
59 lines (47 loc) · 1.9 KB
/
wire_roundtrip.py
File metadata and controls
59 lines (47 loc) · 1.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""Stable-root smoke example for Semafold's exact passthrough path."""
from __future__ import annotations
import numpy as np
from semafold import PassthroughVectorCodec
from semafold import VectorDecodeRequest
from semafold import VectorEncodeRequest
from semafold import VectorEncoding
from semafold.vector.models import EncodeObjective
def _format_bytes(value: int) -> str:
return f"{value:,} B"
def _format_summary(*, codec_family: str, variant_id: str, baseline_bytes: int, artifact_bytes: int) -> str:
    """Build the multi-line size report printed by the example.

    All arguments are keyword-only. The report has a title line, the codec
    identifier as ``family/variant``, the baseline and artifact sizes, and
    the signed difference ``artifact_bytes - baseline_bytes``.

    Returns:
        The report lines joined with newlines, without a trailing newline.
    """
    # Positive when the encoded artifact is larger than the raw input.
    size_difference = artifact_bytes - baseline_bytes
    report_lines = (
        "Semafold wire roundtrip",
        f"codec: {codec_family}/{variant_id}",
        f"baseline bytes: {_format_bytes(baseline_bytes)}",
        f"artifact bytes: {_format_bytes(artifact_bytes)}",
        # ``+`` forces an explicit sign so growth and shrinkage read alike.
        f"bytes delta: {size_difference:+,d} B",
    )
    return "\n".join(report_lines)
def main() -> None:
    """Encode a small float32 matrix, pass it through a dict round trip, and decode it.

    The encoding is serialized with ``to_dict`` and rebuilt with
    ``VectorEncoding.from_dict`` before decoding, so the decode step runs on
    the reconstructed object rather than the one the codec returned.

    Raises:
        SystemExit: with the message ``"round-trip mismatch"`` when the
            decoded array is not equal to the request's data.
    """
    passthrough = PassthroughVectorCodec()
    sample = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
    encode_request = VectorEncodeRequest(
        data=sample,
        objective=EncodeObjective.RECONSTRUCTION,
        role="embedding",
        profile_id="examples.wire_roundtrip",
    )

    # Serialize to a plain dict and rebuild, simulating a trip over the wire.
    wire_payload = passthrough.encode(encode_request).to_dict()
    restored = VectorEncoding.from_dict(wire_payload)
    decode_request = VectorDecodeRequest(encoding=restored)
    roundtripped = passthrough.decode(decode_request)

    # Compare against the request's own view of the data, not the local array.
    matches = np.array_equal(roundtripped.data, encode_request.data)
    if not matches:
        raise SystemExit("round-trip mismatch")

    summary = _format_summary(
        codec_family=restored.codec_family,
        variant_id=restored.variant_id,
        baseline_bytes=int(encode_request.data.nbytes),
        artifact_bytes=int(restored.footprint.total_bytes),
    )
    print(summary)
    print("lossless roundtrip: yes")
    segment_kinds = [segment.segment_kind for segment in restored.segments]
    print(f"segment kinds: {segment_kinds}")
# Run the smoke example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()