diff --git a/internal/storage/memtable/skiplist_test.go b/internal/storage/memtable/skiplist_test.go index 0ed9611..ade959a 100644 --- a/internal/storage/memtable/skiplist_test.go +++ b/internal/storage/memtable/skiplist_test.go @@ -3,6 +3,8 @@ package memtable import ( "bytes" "errors" + "fmt" + "sync" "testing" ) @@ -176,6 +178,104 @@ func TestSkipList_EmptyAndNil(t *testing.T) { } } +// TestSkipList_StrictConcurrency verifies that a single writer goroutine and +// multiple concurrent reader goroutines operating on the same key never observe +// a corrupted or malformed value, confirming read-write lock correctness. +func TestSkipList_StrictConcurrency(t *testing.T) { + skipList := NewSkipList(100000, 12) + var waitGroup sync.WaitGroup + key := []byte("shared-key") + + waitGroup.Add(1) + go func() { + defer waitGroup.Done() + for i := 0; i < 1000; i++ { + val := []byte(fmt.Sprintf("val-%d", i)) + _ = skipList.Put(key, val) + } + }() + + for r := 0; r < 5; r++ { + waitGroup.Add(1) + go func() { + defer waitGroup.Done() + for i := 0; i < 1000; i++ { + val, err := skipList.Get(key) + if err != nil && !errors.Is(err, ErrKeyNotFound) { + t.Errorf("unexpected error on Get: %v", err) + } + if err == nil { + if !bytes.HasPrefix(val, []byte("val-")) { + t.Errorf("corrupted value: %s", val) + } + } + } + }() + } + + waitGroup.Wait() +} + +// TestSkipList_Concurrency is a broad stress test that runs concurrent Puts, +// Deletes, and Iterator traversals across disjoint key ranges simultaneously, +// verifying that no deadlock, panic, or data corruption occurs under contention. +func TestSkipList_Concurrency(t *testing.T) { + skipList := NewSkipList(100000, 12) + var waitGroup sync.WaitGroup + + for i := 0; i < 100; i++ { + waitGroup.Add(1) + go func(id int) { + defer waitGroup.Done() + key := []byte(fmt.Sprintf("key-%03d", id)) + val := []byte(fmt.Sprintf("val-%03d", id)) + if err := skipList.Put(key, val); err != nil { + t.Errorf("put failed for %s: %v", key, err) + } + }(i) + } + + for i := 100; i < 200; i++ { + waitGroup.Add(1) + go func(id int) { + defer waitGroup.Done() + key := []byte(fmt.Sprintf("key-%03d", id)) + if err := skipList.Delete(key); err != nil { + t.Errorf("delete failed for %s: %v", key, err) + } + }(i) + } + + for i := 0; i < 20; i++ { + waitGroup.Add(1) + go func() { + defer waitGroup.Done() + iterator := skipList.NewIterator() + for iterator.Valid() { + _, _, _ = iterator.Next() + } + }() + } + + waitGroup.Wait() + + for i := 0; i < 100; i++ { + key := []byte(fmt.Sprintf("key-%03d", i)) + expected := []byte(fmt.Sprintf("val-%03d", i)) + got, err := skipList.Get(key) + if err != nil || !bytes.Equal(got, expected) { + t.Fatalf("unexpected state for %s: got (%q, %v), want (%q, nil)", key, got, err, expected) + } + } + + for i := 100; i < 200; i++ { + key := []byte(fmt.Sprintf("key-%03d", i)) + if _, err := skipList.Get(key); !errors.Is(err, ErrKeyNotFound) { + t.Fatalf("expected ErrKeyNotFound for deleted/tombstoned key %s, got %v", key, err) + } + } +} + // TestSkipList_SortedOrder verifies that the iterator always returns keys in // ascending lexicographic order, regardless of insertion order. func TestSkipList_SortedOrder(t *testing.T) { diff --git a/internal/storage/wal/.gitkeep b/internal/storage/wal/.gitkeep deleted file mode 100644 index e69de29..0000000 diff --git a/internal/storage/wal/errors.go b/internal/storage/wal/errors.go new file mode 100644 index 0000000..eb71327 --- /dev/null +++ b/internal/storage/wal/errors.go @@ -0,0 +1,49 @@ +// Package wal implements a Write-Ahead Log (WAL) for penguin-db. +// It supports log appending, rotation, record serialization, and replay recovery. +package wal + +import ( + "errors" + "math" +) + +var ( + // ErrInvalidCRC is returned when a record's checksum does not match its payload. + ErrInvalidCRC = errors.New("corrupt wal record: crc32 mismatch") + + // ErrTruncated is returned when the log record has an unexpected size or EOF is reached mid-record. + ErrTruncated = errors.New("corrupt wal record: truncated payload") + + // ErrEmptyKey is returned when attempting to write a log record with a zero-length key. + ErrEmptyKey = errors.New("wal record rejected: key must not be empty") + + // ErrInvalidOpcode is returned when a record carries an opcode that is not + // recognized by the WAL format. Persisting such a record would succeed but + // the entry would be silently skipped during recovery replay. + ErrInvalidOpcode = errors.New("wal record rejected: unrecognized opcode") + + // ErrKeyTooLarge is returned when the key exceeds the maximum representable + // length (math.MaxUint16 bytes) in the on-disk frame format. + ErrKeyTooLarge = errors.New("wal record rejected: key length exceeds maximum of " + uitoa(math.MaxUint16) + " bytes") + + // ErrFrameTooLarge is returned when the total serialized frame size exceeds + // the maximum representable size (math.MaxUint32 bytes) in the on-disk format. + ErrFrameTooLarge = errors.New("wal record rejected: frame size exceeds maximum of " + uitoa(math.MaxUint32) + " bytes") +) + +// uitoa converts an unsigned integer to its decimal string representation. +// Used to embed numeric limits in error sentinel messages at init time without +// depending on strconv or fmt. +func uitoa(val uint64) string { + if val == 0 { + return "0" + } + var buf [20]byte // max uint64 is 20 digits + i := len(buf) + for val > 0 { + i-- + buf[i] = byte(val%10) + '0' + val /= 10 + } + return string(buf[i:]) +} diff --git a/internal/storage/wal/record.go b/internal/storage/wal/record.go new file mode 100644 index 0000000..00b34fb --- /dev/null +++ b/internal/storage/wal/record.go @@ -0,0 +1,126 @@ +package wal + +import ( + "encoding/binary" + "hash/crc32" + "math" +) + +const ( + // OpcodePut represents a put/insert operation in the WAL. + OpcodePut uint8 = 0 + // OpcodeDelete represents a delete operation in the WAL. + OpcodeDelete uint8 = 1 +) + +// Record represents a single change logged in the WAL, wrapping an operation +// type (Opcode), key, and value payload. +type Record struct { + Opcode uint8 + Key []byte + Value []byte +} + +// Define sizes for each field in the frame +const ( + checksumSize = 4 + frameSizeSize = 4 + opcodeSize = 1 + keyLengthSize = 2 + + // Fixed header size is the sum of all fixed-length fields + fixedHeaderSize = checksumSize + frameSizeSize + opcodeSize + keyLengthSize +) + +// Define offsets for each field to eliminate magic slice indices +const ( + checksumOffset = 0 + frameSizeOffset = checksumOffset + checksumSize + opcodeOffset = frameSizeOffset + frameSizeSize + keyLengthOffset = opcodeOffset + opcodeSize + keyOffset = keyLengthOffset + keyLengthSize +) + +// Marshal serializes the Record into a binary frame. +// +// Frame Layout: +// +-------------+-------------+----------+------------+-----------+-----------+ +// | Checksum | Frame Size | Opcode | Key Length | Key | Value | +// | (4 bytes) | (4 bytes) | (1 byte) | (2 bytes) | (n bytes) | (m bytes) | +// +-------------+-------------+----------+------------+-----------+-----------+ +// +// Note: The Checksum (CRC32) covers all bytes starting from the Frame Size. +// +// Marshal returns ErrKeyTooLarge if the key exceeds math.MaxUint16 bytes, or +// ErrFrameTooLarge if the total frame size exceeds math.MaxUint32 bytes, since +// these sizes cannot be represented in the on-disk field widths. +func (record *Record) Marshal() ([]byte, error) { + keyLen := len(record.Key) + valLen := len(record.Value) + + if keyLen > math.MaxUint16 { + return nil, ErrKeyTooLarge + } + + totalFrameSizeBytes := fixedHeaderSize + keyLen + valLen + + if totalFrameSizeBytes > math.MaxUint32 { + return nil, ErrFrameTooLarge + } + + frameBuffer := make([]byte, totalFrameSizeBytes) + + binary.LittleEndian.PutUint32(frameBuffer[frameSizeOffset:opcodeOffset], uint32(totalFrameSizeBytes)) + frameBuffer[opcodeOffset] = record.Opcode + binary.LittleEndian.PutUint16(frameBuffer[keyLengthOffset:keyOffset], uint16(keyLen)) + + copy(frameBuffer[keyOffset:], record.Key) + valueOffset := keyOffset + keyLen + copy(frameBuffer[valueOffset:], record.Value) + + calculatedChecksum := crc32.ChecksumIEEE(frameBuffer[frameSizeOffset:]) + binary.LittleEndian.PutUint32(frameBuffer[checksumOffset:frameSizeOffset], calculatedChecksum) + + return frameBuffer, nil +} + +// UnmarshalRecord deserializes a raw binary frame and reconstructs the original +// Record. It validates the data integrity using the CRC checksum and performs +// bounds checking on payload lengths. +func UnmarshalRecord(frameData []byte) (*Record, error) { + if len(frameData) < fixedHeaderSize { + return nil, ErrTruncated + } + + storedChecksum := binary.LittleEndian.Uint32(frameData[checksumOffset:frameSizeOffset]) + calculatedChecksum := crc32.ChecksumIEEE(frameData[frameSizeOffset:]) + + if storedChecksum != calculatedChecksum { + return nil, ErrInvalidCRC + } + + extractedOpcode := frameData[opcodeOffset] + extractedKeyLength := binary.LittleEndian.Uint16(frameData[keyLengthOffset:keyOffset]) + + if len(frameData) < fixedHeaderSize+int(extractedKeyLength) { + return nil, ErrTruncated + } + + extractedKey := make([]byte, extractedKeyLength) + copy(extractedKey, frameData[keyOffset:keyOffset+int(extractedKeyLength)]) + + extractedValueLength := len(frameData) - (fixedHeaderSize + int(extractedKeyLength)) + var extractedValue []byte + + if extractedValueLength > 0 { + extractedValue = make([]byte, extractedValueLength) + valueOffset := keyOffset + int(extractedKeyLength) + copy(extractedValue, frameData[valueOffset:]) + } + + return &Record{ + Opcode: extractedOpcode, + Key: extractedKey, + Value: extractedValue, + }, nil +} diff --git a/internal/storage/wal/record_test.go b/internal/storage/wal/record_test.go new file mode 100644 index 0000000..9e72300 --- /dev/null +++ b/internal/storage/wal/record_test.go @@ -0,0 +1,321 @@ +package wal + +import ( + "bytes" + "encoding/binary" + "errors" + "hash/crc32" + "math" + "testing" +) + +// mustMarshal is a test helper that calls Marshal and fails the test on error. +func mustMarshal(t *testing.T, r *Record) []byte { + t.Helper() + frame, err := r.Marshal() + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + return frame +} + +// buildValidFrame is a helper that constructs a valid binary frame from a Record. +func buildValidFrame(r *Record) []byte { + frame, err := r.Marshal() + if err != nil { + panic("buildValidFrame: " + err.Error()) + } + return frame +} + +// TestMarshal_FrameLayout tests that the layout of the marshaled frame meets +// specification requirements (proper size, opcode location, and key/value placement). +func TestMarshal_FrameLayout(t *testing.T) { + r := &Record{Opcode: OpcodePut, Key: []byte("hello"), Value: []byte("world")} + frame := mustMarshal(t, r) + + wantSize := fixedHeaderSize + len(r.Key) + len(r.Value) + if len(frame) != wantSize { + t.Fatalf("frame length = %d, want %d", len(frame), wantSize) + } + + storedSize := binary.LittleEndian.Uint32(frame[frameSizeOffset:opcodeOffset]) + if int(storedSize) != wantSize { + t.Errorf("stored size field = %d, want %d", storedSize, wantSize) + } + + if frame[opcodeOffset] != OpcodePut { + t.Errorf("opcode byte = %d, want %d", frame[opcodeOffset], OpcodePut) + } + + storedKeyLen := binary.LittleEndian.Uint16(frame[keyLengthOffset:keyOffset]) + if int(storedKeyLen) != len(r.Key) { + t.Errorf("stored key length = %d, want %d", storedKeyLen, len(r.Key)) + } + + if !bytes.Equal(frame[keyOffset:keyOffset+len(r.Key)], r.Key) { + t.Error("key bytes mismatch in frame") + } + + if !bytes.Equal(frame[keyOffset+len(r.Key):], r.Value) { + t.Error("value bytes mismatch in frame") + } +} + +// TestMarshal_CRCCoversPayload verifies that the CRC-32 checksum is calculated +// correctly over the rest of the frame payload. +func TestMarshal_CRCCoversPayload(t *testing.T) { + r := &Record{Opcode: OpcodeDelete, Key: []byte("k"), Value: nil} + frame := mustMarshal(t, r) + + storedCRC := binary.LittleEndian.Uint32(frame[checksumOffset:frameSizeOffset]) + if storedCRC != crc32.ChecksumIEEE(frame[frameSizeOffset:]) { + t.Errorf("CRC mismatch: stored=%d calculated=%d", storedCRC, crc32.ChecksumIEEE(frame[frameSizeOffset:])) + } +} + +// TestMarshal_OpcodeDelete verifies that deletion records marshal with OpcodeDelete. +func TestMarshal_OpcodeDelete(t *testing.T) { + r := &Record{Opcode: OpcodeDelete, Key: []byte("mykey"), Value: nil} + frame := mustMarshal(t, r) + if frame[opcodeOffset] != OpcodeDelete { + t.Errorf("opcode = %d, want OpcodeDelete (%d)", frame[opcodeOffset], OpcodeDelete) + } +} + +// TestMarshal_ZeroLengthKey tests marshaling records with empty or nil keys. +func TestMarshal_ZeroLengthKey(t *testing.T) { + cases := []struct { + name string + key []byte + }{ + {"empty", []byte{}}, + {"nil", nil}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + frame := mustMarshal(t, &Record{Opcode: OpcodePut, Key: tc.key, Value: []byte("v")}) + if keyLen := binary.LittleEndian.Uint16(frame[keyLengthOffset:keyOffset]); keyLen != 0 { + t.Errorf("key length = %d, want 0", keyLen) + } + }) + } +} + +// TestMarshal_ZeroLengthValue tests marshaling records with empty or nil values. +func TestMarshal_ZeroLengthValue(t *testing.T) { + cases := []struct { + name string + value []byte + }{ + {"empty", []byte{}}, + {"nil", nil}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + r := &Record{Opcode: OpcodePut, Key: []byte("key"), Value: tc.value} + frame := mustMarshal(t, r) + wantSize := fixedHeaderSize + len(r.Key) + if len(frame) != wantSize { + t.Errorf("frame length = %d, want %d", len(frame), wantSize) + } + }) + } +} + +// TestMarshal_LargePayload tests marshalling records with large keys and values. +func TestMarshal_LargePayload(t *testing.T) { + key := bytes.Repeat([]byte("k"), 1024) + value := bytes.Repeat([]byte("v"), 4096) + r := &Record{Opcode: OpcodePut, Key: key, Value: value} + frame := mustMarshal(t, r) + wantSize := fixedHeaderSize + len(key) + len(value) + if len(frame) != wantSize { + t.Errorf("frame size = %d, want %d", len(frame), wantSize) + } + storedCRC := binary.LittleEndian.Uint32(frame[checksumOffset:frameSizeOffset]) + if storedCRC != crc32.ChecksumIEEE(frame[frameSizeOffset:]) { + t.Error("CRC invalid for large payload") + } +} + +// TestMarshal_BinaryKeyAndValue ensures that arbitrary binary data is preserved +// during marshalling and unmarshalling. +func TestMarshal_BinaryKeyAndValue(t *testing.T) { + key := []byte{0x00, 0xFF, 0x7F, 0x80} + value := []byte{0x01, 0x02, 0x03} + r := &Record{Opcode: OpcodePut, Key: key, Value: value} + frame := mustMarshal(t, r) + + recovered, err := UnmarshalRecord(frame) + if err != nil { + t.Fatalf("UnmarshalRecord failed: %v", err) + } + if !bytes.Equal(recovered.Key, key) { + t.Errorf("key mismatch: got %v, want %v", recovered.Key, key) + } + if !bytes.Equal(recovered.Value, value) { + t.Errorf("value mismatch: got %v, want %v", recovered.Value, value) + } +} + +// TestMarshal_Idempotent verifies that Marshal output is identical for successive calls. +func TestMarshal_Idempotent(t *testing.T) { + r := &Record{Opcode: OpcodePut, Key: []byte("idempotent"), Value: []byte("yes")} + frame1 := mustMarshal(t, r) + frame2 := mustMarshal(t, r) + if !bytes.Equal(frame1, frame2) { + t.Error("Marshal is not idempotent for the same record") + } +} + +// TestMarshal_KeyTooLarge verifies that keys exceeding math.MaxUint16 bytes are rejected. +func TestMarshal_KeyTooLarge(t *testing.T) { + oversizedKey := make([]byte, math.MaxUint16+1) + r := &Record{Opcode: OpcodePut, Key: oversizedKey, Value: []byte("v")} + _, err := r.Marshal() + if !errors.Is(err, ErrKeyTooLarge) { + t.Errorf("expected ErrKeyTooLarge, got %v", err) + } +} + +// TestMarshal_MaxKeyLength verifies that a key of exactly math.MaxUint16 bytes is accepted. +func TestMarshal_MaxKeyLength(t *testing.T) { + maxKey := make([]byte, math.MaxUint16) + r := &Record{Opcode: OpcodePut, Key: maxKey, Value: nil} + frame, err := r.Marshal() + if err != nil { + t.Fatalf("Marshal rejected max-length key: %v", err) + } + storedKeyLen := binary.LittleEndian.Uint16(frame[keyLengthOffset:keyOffset]) + if storedKeyLen != math.MaxUint16 { + t.Errorf("stored key length = %d, want %d", storedKeyLen, math.MaxUint16) + } +} + +// TestUnmarshal_RoundTrip_Put checks that a serialized Put record can be accurately reconstructed. +func TestUnmarshal_RoundTrip_Put(t *testing.T) { + original := &Record{Opcode: OpcodePut, Key: []byte("name"), Value: []byte("penguin")} + frame := buildValidFrame(original) + + recovered, err := UnmarshalRecord(frame) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if recovered.Opcode != original.Opcode { + t.Errorf("opcode: got %d, want %d", recovered.Opcode, original.Opcode) + } + if !bytes.Equal(recovered.Key, original.Key) { + t.Errorf("key: got %q, want %q", recovered.Key, original.Key) + } + if !bytes.Equal(recovered.Value, original.Value) { + t.Errorf("value: got %q, want %q", recovered.Value, original.Value) + } +} + +// TestUnmarshal_RoundTrip_Delete checks that a serialized Delete record can be accurately reconstructed. +func TestUnmarshal_RoundTrip_Delete(t *testing.T) { + original := &Record{Opcode: OpcodeDelete, Key: []byte("to-remove"), Value: nil} + frame := buildValidFrame(original) + + recovered, err := UnmarshalRecord(frame) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if recovered.Opcode != OpcodeDelete { + t.Errorf("opcode: got %d, want OpcodeDelete", recovered.Opcode) + } + if !bytes.Equal(recovered.Key, original.Key) { + t.Errorf("key mismatch") + } +} + +// TestUnmarshal_TruncatedFrame_TooShort verifies that short inputs return ErrTruncated. +func TestUnmarshal_TruncatedFrame_TooShort(t *testing.T) { + cases := [][]byte{ + {}, + {0x01}, + {0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A}, + } + for _, data := range cases { + _, err := UnmarshalRecord(data) + if !errors.Is(err, ErrTruncated) { + t.Errorf("input len=%d: expected ErrTruncated, got %v", len(data), err) + } + } +} + +// TestUnmarshal_MinimalFrame_EmptyKeyNoValue verifies parsing of minimally sized frames. +func TestUnmarshal_MinimalFrame_EmptyKeyNoValue(t *testing.T) { + r := &Record{Opcode: OpcodePut, Key: []byte{}, Value: nil} + frame := mustMarshal(t, r) + recovered, err := UnmarshalRecord(frame) + if err != nil { + t.Fatalf("unexpected error for minimal frame: %v", err) + } + if len(recovered.Key) != 0 { + t.Errorf("expected empty key, got %v", recovered.Key) + } +} + +// TestUnmarshal_CRCCorruption tests that altered headers or payloads result in checksum errors. +func TestUnmarshal_CRCCorruption(t *testing.T) { + cases := []struct { + name string + corruptFn func([]byte) + }{ + {"payload byte", func(f []byte) { f[keyOffset+1] ^= 0x01 }}, + {"crc field", func(f []byte) { f[checksumOffset] ^= 0xFF }}, + {"size field", func(f []byte) { binary.LittleEndian.PutUint32(f[frameSizeOffset:opcodeOffset], 9999) }}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + r := &Record{Opcode: OpcodePut, Key: []byte("key"), Value: []byte("val")} + frame := mustMarshal(t, r) + tc.corruptFn(frame) + _, err := UnmarshalRecord(frame) + if !errors.Is(err, ErrInvalidCRC) { + t.Errorf("expected ErrInvalidCRC, got %v", err) + } + }) + } +} + +// TestUnmarshal_KeyLengthExceedsFrame checks that key lengths exceeding frame limits are rejected. +func TestUnmarshal_KeyLengthExceedsFrame(t *testing.T) { + r := &Record{Opcode: OpcodePut, Key: []byte("ab"), Value: []byte("v")} + frame := mustMarshal(t, r) + + binary.LittleEndian.PutUint16(frame[keyLengthOffset:keyOffset], 65535) + newCRC := crc32.ChecksumIEEE(frame[frameSizeOffset:]) + binary.LittleEndian.PutUint32(frame[checksumOffset:frameSizeOffset], newCRC) + + _, err := UnmarshalRecord(frame) + if !errors.Is(err, ErrTruncated) { + t.Errorf("expected ErrTruncated when keyLen > frame size, got %v", err) + } +} + +// TestUnmarshal_NilValueForDeleteRecord checks that deleted records correctly parse nil values. +func TestUnmarshal_NilValueForDeleteRecord(t *testing.T) { + r := &Record{Opcode: OpcodeDelete, Key: []byte("mykey"), Value: nil} + frame := mustMarshal(t, r) + + recovered, err := UnmarshalRecord(frame) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if recovered.Value != nil { + t.Errorf("expected nil Value for delete record, got %v", recovered.Value) + } +} + +// TestOpcodeValues validates correctness of package opcode constants. +func TestOpcodeValues(t *testing.T) { + if OpcodePut != 0 { + t.Errorf("OpcodePut = %d, want 0", OpcodePut) + } + if OpcodeDelete != 1 { + t.Errorf("OpcodeDelete = %d, want 1", OpcodeDelete) + } +} diff --git a/internal/storage/wal/recovery.go b/internal/storage/wal/recovery.go new file mode 100644 index 0000000..8302aa3 --- /dev/null +++ b/internal/storage/wal/recovery.go @@ -0,0 +1,172 @@ +package wal + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "log/slog" + "os" + "path/filepath" + "sort" +) + +// RecordConsumer defines the minimal interface for in-memory storage engine components +// that can consume replayed WAL records during recovery. +type RecordConsumer interface { + Put(key, value []byte) error + Delete(key []byte) error +} + +// Replay scans the specified WAL directory, identifies all segment files +// matching the *.wal pattern, and replays their logged operations onto the +// target RecordConsumer. It returns the highest segment ID found or 1 if fresh. +func Replay(directory string, recordConsumer RecordConsumer) (int, error) { + slog.Debug("starting WAL recovery sequence", "directory", directory) + + entries, err := os.ReadDir(directory) + if err != nil { + if os.IsNotExist(err) { + slog.Debug("WAL directory does not exist, starting fresh", "directory", directory) + return 1, nil + } + return 0, fmt.Errorf("failed to scan WAL directory during recovery boot: %w", err) + } + + var walFiles []string + for _, entry := range entries { + if !entry.IsDir() && filepath.Ext(entry.Name()) == ".wal" { + walFiles = append(walFiles, entry.Name()) + } + } + + slog.Debug("found WAL segments for replay", "count", len(walFiles)) + sort.Strings(walFiles) + + highestSegmentID := 0 + + for _, fileName := range walFiles { + var segmentID int + if n, _ := fmt.Sscanf(fileName, "%d.wal", &segmentID); n != 1 { + slog.Debug("skipping WAL file with unexpected name format", "file", fileName) + continue + } + if segmentID > highestSegmentID { + highestSegmentID = segmentID + } + + filePath := filepath.Join(directory, fileName) + slog.Debug("replaying WAL segment", "segment_id", segmentID, "file", fileName) + + if err := replayFile(filePath, recordConsumer); err != nil { + return 0, fmt.Errorf("critical failure while replaying segment %s: %w", fileName, err) + } + } + + if highestSegmentID == 0 { + highestSegmentID = 1 + } + + slog.Debug("WAL recovery complete", "highest_segment_id", highestSegmentID) + return highestSegmentID, nil +} + +// replayFile opens a single WAL segment, reads it frame-by-frame, validates +// check-sums and frame sizes, and applies the operations onto the RecordConsumer. +// If it encounters corruption or a partial write, it truncates the segment. +func replayFile(filePath string, recordConsumer RecordConsumer) (err error) { + file, err := os.OpenFile(filePath, os.O_RDWR, 0o644) + if err != nil { + return fmt.Errorf("unable to open WAL segment for reading: %w", err) + } + defer func() { + if closeErr := file.Close(); closeErr != nil && err == nil { + err = fmt.Errorf("failed to close WAL segment %s: %w", filePath, closeErr) + } + }() + + var validBytes int64 + var recordsRecovered int + headerBuffer := make([]byte, 8) + + truncateAndSync := func() error { + if truncErr := file.Truncate(validBytes); truncErr != nil { + return truncErr + } + return file.Sync() + } + + for { + _, err := io.ReadFull(file, headerBuffer) + if err != nil { + if errors.Is(err, io.EOF) { + slog.Debug("reached clean EOF for WAL segment", + "file", filepath.Base(filePath), + "records_recovered", recordsRecovered) + break + } + if errors.Is(err, io.ErrUnexpectedEOF) { + slog.Debug("unexpected EOF in header, truncating segment", + "file", filepath.Base(filePath), + "valid_bytes", validBytes) + + return truncateAndSync() + } + return fmt.Errorf("unexpected disk error reading frame header: %w", err) + } + + totalFrameSizeBytes := binary.LittleEndian.Uint32(headerBuffer[4:8]) + if totalFrameSizeBytes < 8 || totalFrameSizeBytes > 128*1024*1024 { + slog.Debug("invalid frame size in header, truncating segment", + "file", filepath.Base(filePath), + "frame_size", totalFrameSizeBytes, + "valid_bytes", validBytes) + + return truncateAndSync() + } + payloadSizeBytes := totalFrameSizeBytes - 8 + + payloadBuffer := make([]byte, payloadSizeBytes) + _, err = io.ReadFull(file, payloadBuffer) + if err != nil { + slog.Debug("unexpected EOF in payload, truncating segment", + "file", filepath.Base(filePath), + "valid_bytes", validBytes) + + return truncateAndSync() + } + + fullFrame := make([]byte, 8+payloadSizeBytes) + copy(fullFrame[:8], headerBuffer) + copy(fullFrame[8:], payloadBuffer) + + record, err := UnmarshalRecord(fullFrame) + if err != nil { + if errors.Is(err, ErrInvalidCRC) || errors.Is(err, ErrTruncated) { + slog.Debug("corrupted frame detected, truncating segment", + "file", filepath.Base(filePath), + "valid_bytes", validBytes, + "error", err) + + return truncateAndSync() + } + return fmt.Errorf("failed to decode valid frame payload: %w", err) + } + + switch record.Opcode { + case OpcodePut: + if putErr := recordConsumer.Put(record.Key, record.Value); putErr != nil { + return fmt.Errorf("memtable rejected recovered put operation: %w", putErr) + } + case OpcodeDelete: + if delErr := recordConsumer.Delete(record.Key); delErr != nil { + return fmt.Errorf("memtable rejected recovered delete operation: %w", delErr) + } + } + + validBytes += int64(totalFrameSizeBytes) + recordsRecovered++ + } + + return nil +} diff --git a/internal/storage/wal/recovery_test.go b/internal/storage/wal/recovery_test.go new file mode 100644 index 0000000..acfbfc4 --- /dev/null +++ b/internal/storage/wal/recovery_test.go @@ -0,0 +1,548 @@ +package wal + +import ( + "encoding/binary" + "fmt" + "hash/crc32" + "os" + "path/filepath" + "testing" +) + +// mockRecordConsumer implements the RecordConsumer interface for tracking replayed actions +// and simulating write failures during recovery tests. +type mockRecordConsumer struct { + puts map[string][]byte + deletes []string + putErr error + delErr error +} + +// newMockRecordConsumer constructs an empty mockRecordConsumer. +func newMockRecordConsumer() *mockRecordConsumer { + return &mockRecordConsumer{puts: make(map[string][]byte)} +} + +// Put records a put operation or returns an injected failure error. +func (m *mockRecordConsumer) Put(key, value []byte) error { + if m.putErr != nil { + return m.putErr + } + m.puts[string(key)] = value + return nil +} + +// Delete records a delete operation or returns an injected failure error. +func (m *mockRecordConsumer) Delete(key []byte) error { + if m.delErr != nil { + return m.delErr + } + m.deletes = append(m.deletes, string(key)) + return nil +} + +// writeRecordsToFile serializes and appends a list of Records to a segment file. +func writeRecordsToFile(t *testing.T, path string, records []*Record) { + t.Helper() + f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err != nil { + t.Fatalf("writeRecordsToFile open: %v", err) + } + defer f.Close() + for _, r := range records { + if _, err := f.Write(mustMarshal(t, r)); err != nil { + t.Fatalf("writeRecordsToFile write: %v", err) + } + } +} + +// segmentPath constructs the absolute filepath for a given segment ID. +func segmentPath(dir string, id int) string { + return filepath.Join(dir, fmt.Sprintf("%06d.wal", id)) +} + +// TestReplay_NonExistentDirectory_ReturnsFreshSegmentID verifies recovery behaves +// correctly if the WAL directory does not exist. +func TestReplay_NonExistentDirectory_ReturnsFreshSegmentID(t *testing.T) { + dir := filepath.Join(t.TempDir(), "no-such-wal") + mem := newMockRecordConsumer() + + nextID, err := Replay(dir, mem) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if nextID != 1 { + t.Errorf("nextID = %d, want 1", nextID) + } + if len(mem.puts) != 0 || len(mem.deletes) != 0 { + t.Error("memtable should be empty for a fresh start") + } +} + +// TestReplay_EmptyDirectory_ReturnsFreshSegmentID checks recovery on an empty directory. +func TestReplay_EmptyDirectory_ReturnsFreshSegmentID(t *testing.T) { + dir := t.TempDir() + mem := newMockRecordConsumer() + + nextID, err := Replay(dir, mem) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if nextID != 1 { + t.Errorf("nextID = %d, want 1", nextID) + } +} + +// TestReplay_NonWALFilesIgnored checks that recovery ignores files without .wal extensions. +func TestReplay_NonWALFilesIgnored(t *testing.T) { + dir := t.TempDir() + for _, name := range []string{"data.sst", "manifest", "000001.log", "LOCK"} { + if err := os.WriteFile(filepath.Join(dir, name), []byte("data"), 0644); err != nil { + t.Fatal(err) + } + } + mem := newMockRecordConsumer() + nextID, err := Replay(dir, mem) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if nextID != 1 { + t.Errorf("nextID = %d, want 1", nextID) + } +} + +// TestReplay_SingleSegment_AllPuts checks that a simple segment with Put operations is fully replayed. +func TestReplay_SingleSegment_AllPuts(t *testing.T) { + dir := t.TempDir() + records := []*Record{ + {Opcode: OpcodePut, Key: []byte("a"), Value: []byte("1")}, + {Opcode: OpcodePut, Key: []byte("b"), Value: []byte("2")}, + {Opcode: OpcodePut, Key: []byte("c"), Value: []byte("3")}, + } + writeRecordsToFile(t, segmentPath(dir, 1), records) + + mem := newMockRecordConsumer() + nextID, err := Replay(dir, mem) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if nextID != 1 { + t.Errorf("nextID = %d, want 1", nextID) + } + for _, r := range records { + if string(mem.puts[string(r.Key)]) != string(r.Value) { + t.Errorf("key %q: got value %q, want %q", r.Key, mem.puts[string(r.Key)], r.Value) + } + } +} + +// TestReplay_SingleSegment_Deletes checks that Delete operations are replayed. +func TestReplay_SingleSegment_Deletes(t *testing.T) { + dir := t.TempDir() + writeRecordsToFile(t, segmentPath(dir, 1), []*Record{ + {Opcode: OpcodePut, Key: []byte("x"), Value: []byte("val")}, + {Opcode: OpcodeDelete, Key: []byte("x")}, + }) + + mem := newMockRecordConsumer() + if _, err := Replay(dir, mem); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(mem.deletes) != 1 || mem.deletes[0] != "x" { + t.Errorf("expected one delete for 'x', got %v", mem.deletes) + } +} + +// TestReplay_MultipleSegments_ReplayedInOrder verifies segment logs are recovered in sequential order. +func TestReplay_MultipleSegments_ReplayedInOrder(t *testing.T) { + dir := t.TempDir() + + writeRecordsToFile(t, segmentPath(dir, 3), []*Record{ + {Opcode: OpcodePut, Key: []byte("k"), Value: []byte("seg3")}, + }) + writeRecordsToFile(t, segmentPath(dir, 1), []*Record{ + {Opcode: OpcodePut, Key: []byte("k"), Value: []byte("seg1")}, + }) + writeRecordsToFile(t, segmentPath(dir, 2), []*Record{ + {Opcode: OpcodePut, Key: []byte("k"), Value: []byte("seg2")}, + }) + + mem := newMockRecordConsumer() + nextID, err := Replay(dir, mem) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(mem.puts["k"]) != "seg3" { + t.Errorf("key 'k' value = %q, want 'seg3'", mem.puts["k"]) + } + if nextID != 3 { + t.Errorf("nextID = %d, want 3", nextID) + } +} + +// TestReplay_ReturnsHighestSegmentID verifies recovery identifies and returns the highest active segment ID. +func TestReplay_ReturnsHighestSegmentID(t *testing.T) { + dir := t.TempDir() + for _, id := range []int{1, 5, 10} { + writeRecordsToFile(t, segmentPath(dir, id), []*Record{ + {Opcode: OpcodePut, Key: []byte(fmt.Sprintf("k%d", id)), Value: []byte("v")}, + }) + } + mem := newMockRecordConsumer() + nextID, err := Replay(dir, mem) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if nextID != 10 { + t.Errorf("nextID = %d, want 10", nextID) + } +} + +// TestReplay_UnknownOpcode_Ignored verifies that records with invalid opcodes are ignored during replay. +func TestReplay_UnknownOpcode_Ignored(t *testing.T) { + dir := t.TempDir() + writeRecordsToFile(t, segmentPath(dir, 1), []*Record{ + {Opcode: 99, Key: []byte("k"), Value: []byte("v")}, + }) + + mem := newMockRecordConsumer() + if _, err := Replay(dir, mem); err != nil { + t.Errorf("unexpected error for unknown opcode: %v", err) + } + if _, ok := mem.puts["k"]; ok { + t.Error("unknown opcode record should not be applied to memtable") + } +} + +// TestReplay_CorruptedCRC_TruncatesFile checks that recovery detects CRC mismatches and truncates. +func TestReplay_CorruptedCRC_TruncatesFile(t *testing.T) { + dir := t.TempDir() + path := segmentPath(dir, 1) + + good := &Record{Opcode: OpcodePut, Key: []byte("good"), Value: []byte("val")} + validBytes := mustMarshal(t, good) + writeRecordsToFile(t, path, []*Record{good}) + + badFrame := mustMarshal(t, &Record{Opcode: OpcodePut, Key: []byte("bad"), Value: []byte("x")}) + badFrame[12] ^= 0xFF + f, err := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, 0644) + if err != nil { + t.Fatal(err) + } + f.Write(badFrame) + f.Close() + + mem := newMockRecordConsumer() + if _, err := Replay(dir, mem); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if _, ok := mem.puts["good"]; !ok { + t.Error("good record was not applied to memtable") + } + if _, ok := mem.puts["bad"]; ok { + t.Error("corrupt record was applied to memtable") + } + + info, err := os.Stat(path) + if err != nil { + t.Fatalf("stat: %v", err) + } + if info.Size() != int64(len(validBytes)) { + t.Errorf("file size after truncate = %d, want %d", info.Size(), len(validBytes)) + } +} + +// TestReplay_TruncatedHeader_TruncatesFile checks truncation when the frame header is truncated. +func TestReplay_TruncatedHeader_TruncatesFile(t *testing.T) { + dir := t.TempDir() + path := segmentPath(dir, 1) + + good := &Record{Opcode: OpcodePut, Key: []byte("ok"), Value: []byte("v")} + validBytes := mustMarshal(t, good) + writeRecordsToFile(t, path, []*Record{good}) + + f, _ := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, 0644) + f.Write([]byte{0xDE, 0xAD, 0xBE, 0xEF, 0xFF}) + f.Close() + + mem := newMockRecordConsumer() + if _, err := Replay(dir, mem); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, ok := mem.puts["ok"]; !ok { + t.Error("valid record was not applied") + } + + info, err := os.Stat(path) + if err != nil { + t.Fatalf("stat: %v", err) + } + if info.Size() != int64(len(validBytes)) { + t.Errorf("file size after truncate = %d, want %d", info.Size(), len(validBytes)) + } +} + +// TestReplay_TruncatedPayload_TruncatesFile checks truncation when the frame payload is truncated. +func TestReplay_TruncatedPayload_TruncatesFile(t *testing.T) { + dir := t.TempDir() + path := segmentPath(dir, 1) + + good := &Record{Opcode: OpcodePut, Key: []byte("safe"), Value: []byte("data")} + validBytes := mustMarshal(t, good) + writeRecordsToFile(t, path, []*Record{good}) + + fakeKey := []byte("payload-cut") + fakeSizeBytes := uint32(fixedHeaderSize + len(fakeKey) + 100) + hdr := make([]byte, checksumSize+frameSizeSize) + binary.LittleEndian.PutUint32(hdr[frameSizeOffset:frameSizeOffset+frameSizeSize], fakeSizeBytes) + binary.LittleEndian.PutUint32(hdr[checksumOffset:checksumOffset+checksumSize], crc32.ChecksumIEEE(hdr[frameSizeOffset:])) + + f, _ := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, 0644) + f.Write(hdr) + f.Write(fakeKey[:3]) + f.Close() + + mem := newMockRecordConsumer() + if _, err := Replay(dir, mem); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, ok := mem.puts["safe"]; !ok { + t.Error("valid record before truncation point was not applied") + } + + info, err := os.Stat(path) + if err != nil { + t.Fatalf("stat: %v", err) + } + if info.Size() != int64(len(validBytes)) { + t.Errorf("file size after truncate = %d, want %d", info.Size(), len(validBytes)) + } +} + +// TestReplay_EmptySegmentFile_NoError checks that an empty WAL segment is handled gracefully. +func TestReplay_EmptySegmentFile_NoError(t *testing.T) { + dir := t.TempDir() + os.WriteFile(segmentPath(dir, 1), []byte{}, 0644) + + mem := newMockRecordConsumer() + nextID, err := Replay(dir, mem) + if err != nil { + t.Fatalf("unexpected error for empty WAL file: %v", err) + } + if nextID != 1 { + t.Errorf("nextID = %d, want 1", nextID) + } +} + +// TestReplay_MemTablePutError_PropagatesError checks failure propagation on MemTable Put errors. +func TestReplay_MemTablePutError_PropagatesError(t *testing.T) { + dir := t.TempDir() + writeRecordsToFile(t, segmentPath(dir, 1), []*Record{ + {Opcode: OpcodePut, Key: []byte("k"), Value: []byte("v")}, + }) + + mem := newMockRecordConsumer() + mem.putErr = fmt.Errorf("memtable full") + + if _, err := Replay(dir, mem); err == nil { + t.Fatal("expected error when memtable.Put fails, got nil") + } +} + +// TestReplay_MemTableDeleteError_PropagatesError checks failure propagation on MemTable Delete errors. +func TestReplay_MemTableDeleteError_PropagatesError(t *testing.T) { + dir := t.TempDir() + writeRecordsToFile(t, segmentPath(dir, 1), []*Record{ + {Opcode: OpcodeDelete, Key: []byte("k")}, + }) + + mem := newMockRecordConsumer() + mem.delErr = fmt.Errorf("read-only memtable") + + if _, err := Replay(dir, mem); err == nil { + t.Fatal("expected error when memtable.Delete fails, got nil") + } +} + +// TestReplay_MixedPutsAndDeletes_CorrectOrder verifies interleaved Puts/Deletes replay in correct order. +func TestReplay_MixedPutsAndDeletes_CorrectOrder(t *testing.T) { + dir := t.TempDir() + mem := newMockRecordConsumer() + + writeRecordsToFile(t, segmentPath(dir, 1), []*Record{ + {Opcode: OpcodePut, Key: []byte("a"), Value: []byte("1")}, + {Opcode: OpcodePut, Key: []byte("b"), Value: []byte("2")}, + {Opcode: OpcodeDelete, Key: []byte("a")}, + {Opcode: OpcodePut, Key: []byte("a"), Value: []byte("3")}, + }) + + if _, err := Replay(dir, mem); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if string(mem.puts["a"]) != "3" { + t.Errorf("key 'a' value = %q, want '3'", mem.puts["a"]) + } + if string(mem.puts["b"]) != "2" { + t.Errorf("key 'b' value = %q, want '2'", mem.puts["b"]) + } + + found := false + for _, d := range mem.deletes { + if d == "a" { + found = true + } + } + if !found { + t.Error("expected a Delete('a') call to memtable") + } +} + +// TestReplay_SegmentsSortedNumerically checks sorting logic on WAL segments with double digits. +func TestReplay_SegmentsSortedNumerically(t *testing.T) { + dir := t.TempDir() + + for i := 1; i <= 12; i++ { + writeRecordsToFile(t, segmentPath(dir, i), []*Record{ + {Opcode: OpcodePut, Key: []byte("seq"), Value: []byte(fmt.Sprintf("%d", i))}, + }) + } + + mem := newMockRecordConsumer() + nextID, err := Replay(dir, mem) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if nextID != 12 { + t.Errorf("nextID = %d, want 12", nextID) + } + if string(mem.puts["seq"]) != "12" { + t.Errorf("value for 'seq' = %q, want '12'", mem.puts["seq"]) + } +} + +// TestReplay_SubdirectoriesAreIgnored verifies subdirectories inside the WAL directory are skipped. +func TestReplay_SubdirectoriesAreIgnored(t *testing.T) { + dir := t.TempDir() + subDir := filepath.Join(dir, "archive.wal") + if err := os.MkdirAll(subDir, 0755); err != nil { + t.Fatal(err) + } + writeRecordsToFile(t, segmentPath(dir, 1), []*Record{ + {Opcode: OpcodePut, Key: []byte("k"), Value: []byte("v")}, + }) + + mem := newMockRecordConsumer() + nextID, err := Replay(dir, mem) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if nextID != 1 { + t.Errorf("nextID = %d, want 1", nextID) + } +} + +// TestReplay_InvalidFrameSize_TooSmall_TruncatesFile verifies truncation on underflowing frame sizes. +func TestReplay_InvalidFrameSize_TooSmall_TruncatesFile(t *testing.T) { + dir := t.TempDir() + path := segmentPath(dir, 1) + + good := &Record{Opcode: OpcodePut, Key: []byte("good"), Value: []byte("val")} + validBytes := mustMarshal(t, good) + writeRecordsToFile(t, path, []*Record{good}) + + hdr := make([]byte, checksumSize+frameSizeSize) + binary.LittleEndian.PutUint32(hdr[frameSizeOffset:frameSizeOffset+frameSizeSize], 7) + binary.LittleEndian.PutUint32(hdr[checksumOffset:checksumOffset+checksumSize], crc32.ChecksumIEEE(hdr[frameSizeOffset:])) + + f, err := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, 0644) + if err != nil { + t.Fatal(err) + } + f.Write(hdr) + f.Close() + + mem := newMockRecordConsumer() + if _, err := Replay(dir, mem); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, ok := mem.puts["good"]; !ok { + t.Error("good record was not applied") + } + + info, err := os.Stat(path) + if err != nil { + t.Fatalf("stat: %v", err) + } + if info.Size() != int64(len(validBytes)) { + t.Errorf("file size after truncate = %d, want %d", info.Size(), len(validBytes)) + } +} + +// TestReplay_InvalidFrameSize_TooLarge_TruncatesFile verifies truncation on excessively large frame sizes. +func TestReplay_InvalidFrameSize_TooLarge_TruncatesFile(t *testing.T) { + dir := t.TempDir() + path := segmentPath(dir, 1) + + good := &Record{Opcode: OpcodePut, Key: []byte("good"), Value: []byte("val")} + validBytes := mustMarshal(t, good) + writeRecordsToFile(t, path, []*Record{good}) + + hdr := make([]byte, checksumSize+frameSizeSize) + binary.LittleEndian.PutUint32(hdr[frameSizeOffset:frameSizeOffset+frameSizeSize], 129*1024*1024) + binary.LittleEndian.PutUint32(hdr[checksumOffset:checksumOffset+checksumSize], crc32.ChecksumIEEE(hdr[frameSizeOffset:])) + + f, _ := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, 0644) + f.Write(hdr) + f.Close() + + mem := newMockRecordConsumer() + if _, err := Replay(dir, mem); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, ok := mem.puts["good"]; !ok { + t.Error("good record was not applied") + } + + info, err := os.Stat(path) + if err != nil { + t.Fatalf("stat: %v", err) + } + if info.Size() != int64(len(validBytes)) { + t.Errorf("file size after truncate = %d, want %d", info.Size(), len(validBytes)) + } +} + +// TestReplay_MalformedWALFilename_Skipped verifies that files ending in .wal but not starting with a number are skipped during replay. +func TestReplay_MalformedWALFilename_Skipped(t *testing.T) { + dir := t.TempDir() + + goodPath := segmentPath(dir, 1) + good := &Record{Opcode: OpcodePut, Key: []byte("good"), Value: []byte("val")} + writeRecordsToFile(t, goodPath, []*Record{good}) + + badPath := filepath.Join(dir, "malformed.wal") + bad := &Record{Opcode: OpcodePut, Key: []byte("bad"), Value: []byte("val")} + f, err := os.Create(badPath) + if err != nil { + t.Fatal(err) + } + f.Write(mustMarshal(t, bad)) + f.Close() + + mem := newMockRecordConsumer() + nextID, err := Replay(dir, mem) + if err != nil { + t.Fatalf("unexpected error replaying: %v", err) + } + if nextID != 1 { + t.Errorf("nextID = %d, want 1", nextID) + } + if _, ok := mem.puts["good"]; !ok { + t.Error("good record was not applied") + } + if _, ok := mem.puts["bad"]; ok { + t.Error("bad record in malformed file was replayed, but should have been skipped") + } +} diff --git a/internal/storage/wal/writer.go b/internal/storage/wal/writer.go new file mode 100644 index 0000000..3a242d4 --- /dev/null +++ b/internal/storage/wal/writer.go @@ -0,0 +1,217 @@ +package wal + +import ( + "fmt" + "log/slog" + "os" + "path/filepath" + "sync" +) + +const MaxSegmentSizeBytes int64 = 32 * 1024 * 1024 +const MaxBatchSizeBytes int64 = 4 * 1024 * 1024 + +// commitTicket represents an ingestion task containing serialized record data +// and a channel to communicate the result of the log write and fsync. +type commitTicket struct { + frameData []byte + resultChan chan error +} + +// LogWriter manages sequential appending to WAL segment files. It coordinates +// concurrent appends, runs a batching background worker to flush writes, and +// handles file rotation when a segment exceeds its capacity. +type LogWriter struct { + directory string + activeFile *os.File + currentSegmentID int + currentSizeBytes int64 + + ingestionChannel chan *commitTicket + stateMutex sync.RWMutex + isClosed bool + + workerWaitGroup sync.WaitGroup + closeOnce sync.Once +} + +// NewLogWriter creates a new LogWriter instance, initializing the WAL directory, +// creating the initial active segment, and launching the background batch worker. +func NewLogWriter(directory string, nextSegmentID int) (*LogWriter, error) { + if err := os.MkdirAll(directory, 0o755); err != nil { + return nil, fmt.Errorf("failed to initialize WAL directory structure at %s: %w", directory, err) + } + + writer := &LogWriter{ + directory: directory, + currentSegmentID: nextSegmentID, + ingestionChannel: make(chan *commitTicket, 10000), + } + + if err := writer.rotateActiveFile(); err != nil { + return nil, err + } + + writer.workerWaitGroup.Add(1) + go writer.batchWorker() + + return writer, nil +} + +// rotateActiveFile closes the current active segment file after fsyncing its data, +// increments the segment ID, and opens a new segment file for write access. +func (writer *LogWriter) rotateActiveFile() error { + if writer.activeFile != nil { + if err := writer.activeFile.Sync(); err != nil { + return fmt.Errorf("failed to sync WAL segment %d during rotation: %w", writer.currentSegmentID, err) + } + if err := writer.activeFile.Close(); err != nil { + return fmt.Errorf("failed to close WAL segment %d during rotation: %w", writer.currentSegmentID, err) + } + writer.currentSegmentID++ + } + + segmentPath := filepath.Join(writer.directory, fmt.Sprintf("%06d.wal", writer.currentSegmentID)) + file, err := os.OpenFile(segmentPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + return fmt.Errorf("failed to open new WAL segment %s: %w", segmentPath, err) + } + + writer.activeFile = file + + info, err := file.Stat() + if err != nil { + file.Close() + return fmt.Errorf("failed to stat WAL segment %s: %w", segmentPath, err) + } + writer.currentSizeBytes = info.Size() + + return nil +} + +// Append writes a single Record into the Write-Ahead Log. It blocks until the +// record is durably persisted (written and synced) or the log is closed. +func (writer *LogWriter) Append(record *Record) error { + if len(record.Key) == 0 { + return ErrEmptyKey + } + + if record.Opcode != OpcodePut && record.Opcode != OpcodeDelete { + return ErrInvalidOpcode + } + + frame, err := record.Marshal() + if err != nil { + return err + } + + ticket := &commitTicket{ + frameData: frame, + resultChan: make(chan error, 1), + } + + writer.stateMutex.RLock() + if writer.isClosed { + writer.stateMutex.RUnlock() + return fmt.Errorf("database engine is currently shutting down, write rejected") + } + + slog.Debug("network thread: dropping record into ingestion channel", "frame_size", len(frame)) + writer.ingestionChannel <- ticket + writer.stateMutex.RUnlock() + + return <-ticket.resultChan +} + +// batchWorker runs in a background goroutine, receiving commit tickets from the +// ingestion channel, batching them, writing to the active file, and syncing them. +func (writer *LogWriter) batchWorker() { + defer writer.workerWaitGroup.Done() + + var commitBatch []*commitTicket + var writeBuffer []byte + + for ticket := range writer.ingestionChannel { + commitBatch, writeBuffer = writer.gatherBatch(ticket, commitBatch, writeBuffer) + + slog.Debug("batch worker: executing group commit", + "batch_size", len(commitBatch), + "total_bytes", len(writeBuffer)) + + writer.writeAndSyncBatch(commitBatch, writeBuffer) + } +} + +// Close closes the active WAL segment file, terminates the background worker, +// and ensures all pending writes have been durably synced to disk. +func (writer *LogWriter) Close() error { + var closeErr error + writer.closeOnce.Do(func() { + writer.stateMutex.Lock() + writer.isClosed = true + close(writer.ingestionChannel) + writer.stateMutex.Unlock() + + writer.workerWaitGroup.Wait() + + if writer.activeFile != nil { + if syncErr := writer.activeFile.Sync(); syncErr != nil { + closeErr = syncErr + } + if err := writer.activeFile.Close(); err != nil && closeErr == nil { + closeErr = err + } + } + }) + return closeErr +} + +// writeAndSyncBatch handles the disk I/O, segment rotation, and network thread +// notification for a gathered batch of records. +func (writer *LogWriter) writeAndSyncBatch(batch []*commitTicket, buffer []byte) { + if writer.currentSizeBytes+int64(len(buffer)) > MaxSegmentSizeBytes { + if err := writer.rotateActiveFile(); err != nil { + for _, ticket := range batch { + ticket.resultChan <- err + } + return + } + } + + _, err := writer.activeFile.Write(buffer) + if err == nil { + err = writer.activeFile.Sync() + } + + if err == nil { + writer.currentSizeBytes += int64(len(buffer)) + } else { + slog.Debug("batch worker: fsync/write failed", "error", err) + } + + for _, ticket := range batch { + ticket.resultChan <- err + } +} + +// gatherBatch pulls tickets from the ingestion channel up to the MaxBatchSizeBytes limit. +// It takes the first ticket yielded by the range loop and drains the remaining buffered tickets. +func (writer *LogWriter) gatherBatch(firstTicket *commitTicket, inBatch []*commitTicket, inBuffer []byte) (outBatch []*commitTicket, outBuffer []byte) { + outBatch = inBatch[:0] + outBuffer = inBuffer[:0] + + outBatch = append(outBatch, firstTicket) + outBuffer = append(outBuffer, firstTicket.frameData...) + + pendingWrites := len(writer.ingestionChannel) + for range pendingWrites { + if int64(len(outBuffer)) >= MaxBatchSizeBytes { + break + } + ticket := <-writer.ingestionChannel + outBatch = append(outBatch, ticket) + outBuffer = append(outBuffer, ticket.frameData...) + } + + return outBatch, outBuffer +} diff --git a/internal/storage/wal/writer_test.go b/internal/storage/wal/writer_test.go new file mode 100644 index 0000000..24ef4ca --- /dev/null +++ b/internal/storage/wal/writer_test.go @@ -0,0 +1,815 @@ +package wal + +import ( + "errors" + "fmt" + "os" + "sync" + "sync/atomic" + "testing" + "time" +) + +// TestNewLogWriter_CreatesDirectory checks that initializing LogWriter creates the target directory. +func TestNewLogWriter_CreatesDirectory(t *testing.T) { + base := t.TempDir() + dir := fmt.Sprintf("%s/nested/wal", base) + + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer w.Close() + + if _, err := os.Stat(dir); os.IsNotExist(err) { + t.Error("WAL directory was not created") + } +} + +// TestNewLogWriter_CreatesFirstSegmentFile checks that the first log segment is created upon initialization. +func TestNewLogWriter_CreatesFirstSegmentFile(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer w.Close() + + if _, err := os.Stat(segmentPath(dir, 1)); os.IsNotExist(err) { + t.Error("first segment file 000001.wal was not created") + } +} + +// TestNewLogWriter_RespectsStartSegmentID verifies that the writer uses the provided initial segment ID. +func TestNewLogWriter_RespectsStartSegmentID(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 7) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer w.Close() + + if _, err := os.Stat(segmentPath(dir, 7)); os.IsNotExist(err) { + t.Error("segment 000007.wal was not created for startSegmentID=7") + } +} + +// TestNewLogWriter_InvalidDirectory_ReturnsError checks that invalid directories return initialization errors. +func TestNewLogWriter_InvalidDirectory_ReturnsError(t *testing.T) { + tmpFile, err := os.CreateTemp(t.TempDir(), "not-a-dir") + if err != nil { + t.Fatal(err) + } + tmpFile.Close() + + _, err = NewLogWriter(tmpFile.Name()+"/subdir", 1) + if err == nil { + t.Error("expected error when directory path is inside a file, got nil") + } +} + +// TestAppend_SingleRecord_WrittenToDisk checks that a single record write changes the file size on disk. +func TestAppend_SingleRecord_WrittenToDisk(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + defer w.Close() + + r := &Record{Opcode: OpcodePut, Key: []byte("hello"), Value: []byte("world")} + if err := w.Append(r); err != nil { + t.Fatalf("Append: %v", err) + } + + info, err := os.Stat(segmentPath(dir, 1)) + if err != nil { + t.Fatalf("stat: %v", err) + } + if info.Size() == 0 { + t.Error("segment file is empty after Append") + } +} + +// TestAppend_MultipleRecords_AllWritten checks that multiple sequential records are successfully appended. +func TestAppend_MultipleRecords_AllWritten(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + + records := []*Record{ + {Opcode: OpcodePut, Key: []byte("k1"), Value: []byte("v1")}, + {Opcode: OpcodePut, Key: []byte("k2"), Value: []byte("v2")}, + {Opcode: OpcodeDelete, Key: []byte("k1")}, + } + for _, r := range records { + if err := w.Append(r); err != nil { + t.Fatalf("Append: %v", err) + } + } + + if err := w.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + + mem := newMockRecordConsumer() + if _, err := Replay(dir, mem); err != nil { + t.Fatalf("Replay: %v", err) + } + if _, ok := mem.puts["k2"]; !ok { + t.Error("k2 not in memtable after replay") + } + found := false + for _, d := range mem.deletes { + if d == "k1" { + found = true + } + } + if !found { + t.Error("delete for k1 not replayed") + } +} + +// TestAppend_RecordRoundtrip_ViaReplay verifies that written records are completely recoverable. +func TestAppend_RecordRoundtrip_ViaReplay(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + + original := &Record{Opcode: OpcodePut, Key: []byte("penguindb"), Value: []byte("rocks")} + if err := w.Append(original); err != nil { + t.Fatalf("Append: %v", err) + } + w.Close() + + mem := newMockRecordConsumer() + if _, err := Replay(dir, mem); err != nil { + t.Fatalf("Replay: %v", err) + } + if string(mem.puts["penguindb"]) != "rocks" { + t.Errorf("expected value 'rocks', got %q", mem.puts["penguindb"]) + } +} + +// TestAppend_AfterClose_ReturnsError checks that writes are rejected after Close is called. +func TestAppend_AfterClose_ReturnsError(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + w.Close() + + r := &Record{Opcode: OpcodePut, Key: []byte("k"), Value: []byte("v")} + if err = w.Append(r); err == nil { + t.Error("expected error when Appending after Close, got nil") + } +} + +// TestAppend_EmptyKey_Rejected verifies empty key appends are rejected with ErrEmptyKey. +func TestAppend_EmptyKey_Rejected(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + defer w.Close() + + r := &Record{Opcode: OpcodePut, Key: []byte{}, Value: []byte("v")} + if err = w.Append(r); !errors.Is(err, ErrEmptyKey) { + t.Errorf("expected ErrEmptyKey for empty key, got %v", err) + } +} + +// TestAppend_NilKey_Rejected verifies nil key appends are rejected with ErrEmptyKey. +func TestAppend_NilKey_Rejected(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + defer w.Close() + + r := &Record{Opcode: OpcodePut, Key: nil, Value: []byte("v")} + if err = w.Append(r); !errors.Is(err, ErrEmptyKey) { + t.Errorf("expected ErrEmptyKey for nil key, got %v", err) + } +} + +// TestAppend_EmptyValue_Allowed checks that empty/nil values are allowed in WAL entries. +func TestAppend_EmptyValue_Allowed(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + defer w.Close() + + r := &Record{Opcode: OpcodePut, Key: []byte("k"), Value: []byte{}} + if err := w.Append(r); err != nil { + t.Errorf("unexpected error appending record with empty value: %v", err) + } +} + +// TestAppend_InvalidOpcode_Rejected verifies that records with unrecognized opcodes +// are rejected with ErrInvalidOpcode instead of being silently persisted. +func TestAppend_InvalidOpcode_Rejected(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + defer w.Close() + + for _, opcode := range []uint8{2, 99, 255} { + r := &Record{Opcode: opcode, Key: []byte("k"), Value: []byte("v")} + if err := w.Append(r); !errors.Is(err, ErrInvalidOpcode) { + t.Errorf("opcode %d: expected ErrInvalidOpcode, got %v", opcode, err) + } + } +} + +// TestRotation_NewSegmentCreatedAfterSizeExceeded checks that segment file rotates on limit overrun. +func TestRotation_NewSegmentCreatedAfterSizeExceeded(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + defer w.Close() + + w.currentSizeBytes = MaxSegmentSizeBytes - 1 + + r := &Record{Opcode: OpcodePut, Key: []byte("trigger"), Value: []byte("rotation")} + if err := w.Append(r); err != nil { + t.Fatalf("Append: %v", err) + } + + if _, err := os.Stat(segmentPath(dir, 2)); os.IsNotExist(err) { + t.Error("expected segment 000002.wal after rotation") + } +} + +// TestRotation_OldSegmentSyncedOnRotation verifies old segment files are synced upon rotation. +func TestRotation_OldSegmentSyncedOnRotation(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + defer w.Close() + + r1 := &Record{Opcode: OpcodePut, Key: []byte("before"), Value: []byte("rotation")} + if err := w.Append(r1); err != nil { + t.Fatalf("Append r1: %v", err) + } + + w.currentSizeBytes = MaxSegmentSizeBytes + + r2 := &Record{Opcode: OpcodePut, Key: []byte("after"), Value: []byte("rotation")} + if err := w.Append(r2); err != nil { + t.Fatalf("Append r2: %v", err) + } + + w.Close() + + for _, id := range []int{1, 2} { + info, err := os.Stat(segmentPath(dir, id)) + if err != nil { + t.Fatalf("stat segment %d: %v", id, err) + } + if info.Size() == 0 { + t.Errorf("segment %d is empty, expected data", id) + } + } +} + +// TestRotation_SizeResetAfterRotation verifies that segment size tracking resets after file rotation. +func TestRotation_SizeResetAfterRotation(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + defer w.Close() + + w.currentSizeBytes = MaxSegmentSizeBytes + r := &Record{Opcode: OpcodePut, Key: []byte("k"), Value: []byte("v")} + if err := w.Append(r); err != nil { + t.Fatalf("Append: %v", err) + } + + if w.currentSizeBytes >= MaxSegmentSizeBytes { + t.Errorf("currentSizeBytes = %d after rotation, expected < %d", + w.currentSizeBytes, MaxSegmentSizeBytes) + } +} + +// TestRotation_ReopenedSegment_SizeAccountedFor verifies that when a writer +// reopens an existing segment (e.g. after recovery), it seeds currentSizeBytes +// from the file's actual size so the rotation threshold isn't silently bypassed. +func TestRotation_ReopenedSegment_SizeAccountedFor(t *testing.T) { + dir := t.TempDir() + + // Phase 1: Write data to segment 1, then close. + w1, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter (phase 1): %v", err) + } + for i := 0; i < 10; i++ { + r := &Record{Opcode: OpcodePut, Key: []byte(fmt.Sprintf("k%d", i)), Value: []byte("v")} + if err := w1.Append(r); err != nil { + t.Fatalf("Append (phase 1): %v", err) + } + } + if err := w1.Close(); err != nil { + t.Fatalf("Close (phase 1): %v", err) + } + + // Get actual file size on disk. + info, err := os.Stat(segmentPath(dir, 1)) + if err != nil { + t.Fatalf("Stat: %v", err) + } + fileSize := info.Size() + if fileSize == 0 { + t.Fatal("segment file is unexpectedly empty") + } + + // Phase 2: Reopen the same segment (simulating post-recovery). + w2, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter (phase 2): %v", err) + } + defer w2.Close() + + // The writer must account for the preexisting data. + if w2.currentSizeBytes != fileSize { + t.Errorf("currentSizeBytes = %d, want %d (preexisting file size)", + w2.currentSizeBytes, fileSize) + } +} + +// TestAppend_ConcurrentWrites_NoDataRace validates concurrent WAL writes do not cause data races. +func TestAppend_ConcurrentWrites_NoDataRace(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + defer w.Close() + + const numGoroutines = 50 + const recordsPerGoroutine = 20 + + var wg sync.WaitGroup + var errCount int64 + + for g := 0; g < numGoroutines; g++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for i := 0; i < recordsPerGoroutine; i++ { + r := &Record{ + Opcode: OpcodePut, + Key: []byte(fmt.Sprintf("goroutine-%d-key-%d", id, i)), + Value: []byte(fmt.Sprintf("value-%d", i)), + } + if err := w.Append(r); err != nil { + atomic.AddInt64(&errCount, 1) + } + } + }(g) + } + + wg.Wait() + + if errCount > 0 { + t.Errorf("%d Append calls failed under concurrency", errCount) + } +} + +// TestAppend_ConcurrentWrites_AllRecordsRecoverable verifies all concurrent writes are cleanly replayed. +func TestAppend_ConcurrentWrites_AllRecordsRecoverable(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + + const numRecords = 100 + keys := make([]string, numRecords) + for i := 0; i < numRecords; i++ { + keys[i] = fmt.Sprintf("concurrent-key-%04d", i) + } + + var wg sync.WaitGroup + for i := 0; i < numRecords; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + r := &Record{ + Opcode: OpcodePut, + Key: []byte(keys[idx]), + Value: []byte("ok"), + } + _ = w.Append(r) + }(i) + } + wg.Wait() + w.Close() + + mem := newMockRecordConsumer() + if _, err := Replay(dir, mem); err != nil { + t.Fatalf("Replay: %v", err) + } + for _, k := range keys { + if _, ok := mem.puts[k]; !ok { + t.Errorf("key %q missing after concurrent write + replay", k) + } + } +} + +// TestClose_IsIdempotent validates that closing multiple times behaves correctly. +func TestClose_IsIdempotent(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + + if err := w.Close(); err != nil { + t.Errorf("first Close error: %v", err) + } + if err := w.Close(); err != nil { + t.Logf("second Close returned (expected nil): %v", err) + } +} + +// TestClose_BlocksUntilWorkerDone checks that Close blocks until background activities terminate. +func TestClose_BlocksUntilWorkerDone(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + + done := make(chan struct{}) + go func() { + w.Close() + close(done) + }() + + select { + case <-done: + case <-time.After(3 * time.Second): + t.Error("Close did not return within 3 seconds") + } +} + +// TestClose_SyncsDataToDisk verifies that closing the log writer flushes in-flight data. +func TestClose_SyncsDataToDisk(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + + r := &Record{Opcode: OpcodePut, Key: []byte("durable"), Value: []byte("yes")} + if err := w.Append(r); err != nil { + t.Fatalf("Append: %v", err) + } + if err := w.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + + mem := newMockRecordConsumer() + if _, err := Replay(dir, mem); err != nil { + t.Fatalf("Replay: %v", err) + } + if string(mem.puts["durable"]) != "yes" { + t.Error("data written before Close was not recovered") + } +} + +// TestMaxSegmentSizeBytes_Is32MB checks the segment size boundary constant. +func TestMaxSegmentSizeBytes_Is32MB(t *testing.T) { + expected := int64(32 * 1024 * 1024) + if MaxSegmentSizeBytes != expected { + t.Errorf("MaxSegmentSizeBytes = %d, want %d (32 MiB)", MaxSegmentSizeBytes, expected) + } +} + +// TestBatchWorker_GroupCommit_AllTicketsSignalled verifies that all concurrent ticket requests receive replies. +func TestBatchWorker_GroupCommit_AllTicketsSignalled(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + defer w.Close() + + const n = 200 + errs := make([]error, n) + var wg sync.WaitGroup + for i := 0; i < n; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + r := &Record{ + Opcode: OpcodePut, + Key: []byte(fmt.Sprintf("batch-key-%d", idx)), + Value: []byte("v"), + } + errs[idx] = w.Append(r) + }(i) + } + wg.Wait() + + for i, e := range errs { + if e != nil { + t.Errorf("Append[%d] returned error: %v", i, e) + } + } +} + +// TestMaxBatchSizeBytes_Is4MB checks the batch size limit constant. +func TestMaxBatchSizeBytes_Is4MB(t *testing.T) { + expected := int64(4 * 1024 * 1024) + if MaxBatchSizeBytes != expected { + t.Errorf("MaxBatchSizeBytes = %d, want %d (4 MiB)", MaxBatchSizeBytes, expected) + } +} + +// TestAppend_ConcurrentWithClose_NoHang verifies that goroutines calling Append +// while Close is invoked concurrently do not deadlock or panic. +func TestAppend_ConcurrentWithClose_NoHang(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + + const numWriters = 50 + var wg sync.WaitGroup + + // Launch many concurrent writers. + for i := 0; i < numWriters; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < 20; j++ { + r := &Record{ + Opcode: OpcodePut, + Key: []byte(fmt.Sprintf("race-key-%d-%d", id, j)), + Value: []byte("v"), + } + _ = w.Append(r) // may return shutdown error, that's fine + } + }(i) + } + + // Close concurrently with writers. + done := make(chan struct{}) + go func() { + time.Sleep(1 * time.Millisecond) + w.Close() + close(done) + }() + + wg.Wait() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Close or writers hung for more than 5 seconds") + } +} + +// TestAppend_AfterClose_ReturnsShutdownError checks the specific rejection path +// where Append observes isClosed=true under the read lock. +func TestAppend_AfterClose_ReturnsShutdownError(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + + // Write one record to prove the writer works. + r := &Record{Opcode: OpcodePut, Key: []byte("before"), Value: []byte("close")} + if err := w.Append(r); err != nil { + t.Fatalf("Append before Close: %v", err) + } + + w.Close() + + // Multiple post-close appends should all return errors, never panic. + for i := 0; i < 10; i++ { + r := &Record{Opcode: OpcodePut, Key: []byte(fmt.Sprintf("post-%d", i)), Value: []byte("v")} + if err := w.Append(r); err == nil { + t.Errorf("Append[%d] after Close returned nil, expected error", i) + } + } +} + +// TestClose_ConcurrentCalls_NoPanic verifies that calling Close from multiple +// goroutines simultaneously does not panic or return inconsistent errors. +func TestClose_ConcurrentCalls_NoPanic(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + + // Write some data first. + for i := 0; i < 10; i++ { + r := &Record{Opcode: OpcodePut, Key: []byte(fmt.Sprintf("k%d", i)), Value: []byte("v")} + if err := w.Append(r); err != nil { + t.Fatalf("Append: %v", err) + } + } + + const numClosers = 10 + var wg sync.WaitGroup + for i := 0; i < numClosers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _ = w.Close() + }() + } + + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("concurrent Close calls hung for more than 5 seconds") + } +} + +// TestClose_DrainsInFlightTickets ensures that records already in the ingestion +// channel at the time of Close are still flushed and recoverable. +func TestClose_DrainsInFlightTickets(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + + const numRecords = 50 + var wg sync.WaitGroup + errs := make([]error, numRecords) + + for i := 0; i < numRecords; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + r := &Record{ + Opcode: OpcodePut, + Key: []byte(fmt.Sprintf("drain-%04d", idx)), + Value: []byte("v"), + } + errs[idx] = w.Append(r) + }(i) + } + wg.Wait() + + if err := w.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + + // Count how many succeeded. + var successCount int + for _, e := range errs { + if e == nil { + successCount++ + } + } + if successCount == 0 { + t.Fatal("no records were successfully appended") + } + + // Verify that all successfully appended records are recoverable. + mem := newMockRecordConsumer() + if _, err := Replay(dir, mem); err != nil { + t.Fatalf("Replay: %v", err) + } + + for i, e := range errs { + key := fmt.Sprintf("drain-%04d", i) + if e == nil { + if _, ok := mem.puts[key]; !ok { + t.Errorf("record %q was accepted by Append but not recovered by Replay", key) + } + } + } +} + +// TestBatchWorker_ExitsCleanly_WhenChannelClosed verifies the batchWorker goroutine +// terminates cleanly when the ingestion channel is closed (the new for-range pattern). +func TestBatchWorker_ExitsCleanly_WhenChannelClosed(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + + // Append a record to prove the worker is running. + r := &Record{Opcode: OpcodePut, Key: []byte("alive"), Value: []byte("yes")} + if err := w.Append(r); err != nil { + t.Fatalf("Append: %v", err) + } + + // Close should cause the channel to close and the worker to exit. + done := make(chan struct{}) + go func() { + w.Close() + close(done) + }() + + select { + case <-done: + // Worker exited cleanly. + case <-time.After(3 * time.Second): + t.Fatal("batchWorker did not exit within 3 seconds after channel close") + } + + // Confirm the data is durable. + mem := newMockRecordConsumer() + if _, err := Replay(dir, mem); err != nil { + t.Fatalf("Replay: %v", err) + } + if string(mem.puts["alive"]) != "yes" { + t.Error("record written before Close was not recovered") + } +} + +// TestAppend_ConcurrentWritesDuringClose_AllResolve verifies that every goroutine +// that calls Append gets a definitive result (success or error), even when Close +// is called concurrently. No goroutine should hang. +func TestAppend_ConcurrentWritesDuringClose_AllResolve(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + + const numWriters = 100 + results := make(chan error, numWriters) + + var wg sync.WaitGroup + for i := 0; i < numWriters; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + r := &Record{ + Opcode: OpcodePut, + Key: []byte(fmt.Sprintf("resolve-%d", id)), + Value: []byte("v"), + } + results <- w.Append(r) + }(i) + } + + // Close while writers are still racing. + go func() { + time.Sleep(500 * time.Microsecond) + w.Close() + }() + + // All writers must eventually return. + allDone := make(chan struct{}) + go func() { + wg.Wait() + close(allDone) + }() + + select { + case <-allDone: + case <-time.After(5 * time.Second): + t.Fatal("not all Append goroutines resolved within 5 seconds") + } + + close(results) + var succeeded, failed int + for err := range results { + if err == nil { + succeeded++ + } else { + failed++ + } + } + t.Logf("concurrent close test: %d succeeded, %d rejected", succeeded, failed) + + if succeeded+failed != numWriters { + t.Errorf("expected %d total results, got %d", numWriters, succeeded+failed) + } +}