Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 49 additions & 70 deletions zstd/src/main/java/io/github/dfa1/zstd/ZstdDictionary.java
Original file line number Diff line number Diff line change
Expand Up @@ -96,38 +96,18 @@ public static ZstdDictionary of(byte[] raw) {
/// @throws ZstdException if training fails (commonly: not enough sample data)
public static ZstdDictionary train(List<byte[]> samples, int maxDictBytes) {
Objects.requireNonNull(samples, SAMPLES);
if (samples.isEmpty()) {
throw new ZstdException("cannot train a dictionary from zero samples");
}
long total = 0;
for (byte[] s : samples) {
total += s.length;
}
requireNonEmpty(samples, "train");
try (Arena arena = Arena.ofConfined()) {
// flatten all samples into one buffer + a parallel size_t[] of lengths
MemorySegment flat = arena.allocate(Math.max(total, 1));
MemorySegment sizes = arena.allocate(JAVA_LONG, samples.size());
long offset = 0;
for (int i = 0; i < samples.size(); i++) {
byte[] s = samples.get(i);
MemorySegment.copy(s, 0, flat, JAVA_BYTE, offset, s.length);
sizes.setAtIndex(JAVA_LONG, i, s.length);
offset += s.length;
}
FlatSamples in = flatten(arena, samples);
MemorySegment dictBuf = arena.allocate(maxDictBytes);
long produced;
try {
produced = (long) Bindings.ZDICT_TRAIN.invokeExact(
dictBuf, (long) maxDictBytes, flat, sizes, samples.size());
dictBuf, (long) maxDictBytes, in.data(), in.sizes(), in.count());
} catch (Throwable t) {
throw NativeCall.rethrow(t);
}
if (zdictIsError(produced)) {
throw new ZstdException("dictionary training failed: " + zdictErrorName(produced));
}
byte[] out = new byte[Math.toIntExact(produced)];
MemorySegment.copy(dictBuf, JAVA_BYTE, 0, out, 0, out.length);
return new ZstdDictionary(out);
return toDictionary(dictBuf, produced, "dictionary training");
}
}

Expand Down Expand Up @@ -180,23 +160,9 @@ public static ZstdDictionary trainFastCover(List<byte[]> samples, int maxDictByt
private static ZstdDictionary optimize(List<byte[]> samples, int maxDictBytes,
int compressionLevel, boolean fast) {
Objects.requireNonNull(samples, SAMPLES);
if (samples.isEmpty()) {
throw new ZstdException("cannot train a dictionary from zero samples");
}
requireNonEmpty(samples, "train");
try (Arena arena = Arena.ofConfined()) {
long total = 0;
for (byte[] s : samples) {
total += s.length;
}
MemorySegment flat = arena.allocate(Math.max(total, 1));
MemorySegment sizes = arena.allocate(JAVA_LONG, samples.size());
long offset = 0;
for (int i = 0; i < samples.size(); i++) {
byte[] s = samples.get(i);
MemorySegment.copy(s, 0, flat, JAVA_BYTE, offset, s.length);
sizes.setAtIndex(JAVA_LONG, i, s.length);
offset += s.length;
}
FlatSamples in = flatten(arena, samples);
// zeroed params (auto-tune k/d/steps); set single-threaded + target level.
MemoryLayout layout = fast ? FASTCOVER_PARAMS : COVER_PARAMS;
MemorySegment params = arena.allocate(layout);
Expand All @@ -207,16 +173,11 @@ private static ZstdDictionary optimize(List<byte[]> samples, int maxDictBytes,
long produced;
try {
produced = (long) handle.invokeExact(
dictBuf, (long) maxDictBytes, flat, sizes, samples.size(), params);
dictBuf, (long) maxDictBytes, in.data(), in.sizes(), in.count(), params);
} catch (Throwable t) {
throw NativeCall.rethrow(t);
}
if (zdictIsError(produced)) {
throw new ZstdException("dictionary training failed: " + zdictErrorName(produced));
}
byte[] out = new byte[Math.toIntExact(produced)];
MemorySegment.copy(dictBuf, JAVA_BYTE, 0, out, 0, out.length);
return new ZstdDictionary(out);
return toDictionary(dictBuf, produced, "dictionary training");
}
}

Expand All @@ -235,23 +196,9 @@ public static ZstdDictionary finalizeFrom(byte[] content, List<byte[]> samples,
int maxDictBytes, int compressionLevel) {
Objects.requireNonNull(content, "content");
Objects.requireNonNull(samples, SAMPLES);
if (samples.isEmpty()) {
throw new ZstdException("cannot finalise a dictionary from zero samples");
}
requireNonEmpty(samples, "finalise");
try (Arena arena = Arena.ofConfined()) {
long total = 0;
for (byte[] s : samples) {
total += s.length;
}
MemorySegment flat = arena.allocate(Math.max(total, 1));
MemorySegment sizes = arena.allocate(JAVA_LONG, samples.size());
long offset = 0;
for (int i = 0; i < samples.size(); i++) {
byte[] s = samples.get(i);
MemorySegment.copy(s, 0, flat, JAVA_BYTE, offset, s.length);
sizes.setAtIndex(JAVA_LONG, i, s.length);
offset += s.length;
}
FlatSamples in = flatten(arena, samples);
MemorySegment contentSeg = Zstd.copyIn(arena, content);
MemorySegment params = arena.allocate(Bindings.ZDICT_PARAMS_LAYOUT);
params.set(JAVA_INT, 0, compressionLevel); // compressionLevel; notificationLevel/dictID = 0
Expand All @@ -260,17 +207,49 @@ public static ZstdDictionary finalizeFrom(byte[] content, List<byte[]> samples,
try {
produced = (long) Bindings.ZDICT_FINALIZE_DICTIONARY.invokeExact(
dictBuf, (long) maxDictBytes, contentSeg, (long) content.length,
flat, sizes, samples.size(), params);
in.data(), in.sizes(), in.count(), params);
} catch (Throwable t) {
throw NativeCall.rethrow(t);
}
if (zdictIsError(produced)) {
throw new ZstdException("dictionary finalisation failed: " + zdictErrorName(produced));
}
byte[] out = new byte[Math.toIntExact(produced)];
MemorySegment.copy(dictBuf, JAVA_BYTE, 0, out, 0, out.length);
return new ZstdDictionary(out);
return toDictionary(dictBuf, produced, "dictionary finalisation");
}
}

/// One native buffer holding all samples back to back, plus a parallel
/// `size_t[]` of their lengths — the shape the ZDICT trainers consume.
private record FlatSamples(MemorySegment data, MemorySegment sizes, int count) {
}

private static FlatSamples flatten(Arena arena, List<byte[]> samples) {
long total = 0;
for (byte[] s : samples) {
total += s.length;
}
MemorySegment data = arena.allocate(Math.max(total, 1));
MemorySegment sizes = arena.allocate(JAVA_LONG, samples.size());
long offset = 0;
for (int i = 0; i < samples.size(); i++) {
byte[] s = samples.get(i);
MemorySegment.copy(s, 0, data, JAVA_BYTE, offset, s.length);
sizes.setAtIndex(JAVA_LONG, i, s.length);
offset += s.length;
}
return new FlatSamples(data, sizes, samples.size());
}

private static void requireNonEmpty(List<byte[]> samples, String verb) {
if (samples.isEmpty()) {
throw new ZstdException("cannot " + verb + " a dictionary from zero samples");
}
}

private static ZstdDictionary toDictionary(MemorySegment dictBuf, long produced, String what) {
if (zdictIsError(produced)) {
throw new ZstdException(what + " failed: " + zdictErrorName(produced));
}
byte[] out = new byte[Math.toIntExact(produced)];
MemorySegment.copy(dictBuf, JAVA_BYTE, 0, out, 0, out.length);
return new ZstdDictionary(out);
}

/// The dictionary id zstd stamps into frames compressed with this dictionary,
Expand Down
Loading