From d32e3dbb7fc6abaca59127a87635c6f3e27a37e8 Mon Sep 17 00:00:00 2001 From: Souvik606 Date: Fri, 5 Jun 2026 02:26:14 +0530 Subject: [PATCH 01/18] Implement The Binary Record Format and made Marshal and Unmarshal Functions --- internal/storage/wal/.gitkeep | 0 internal/storage/wal/errors.go | 8 ++++ internal/storage/wal/record.go | 72 ++++++++++++++++++++++++++++++++++ 3 files changed, 80 insertions(+) delete mode 100644 internal/storage/wal/.gitkeep create mode 100644 internal/storage/wal/errors.go create mode 100644 internal/storage/wal/record.go diff --git a/internal/storage/wal/.gitkeep b/internal/storage/wal/.gitkeep deleted file mode 100644 index e69de29..0000000 diff --git a/internal/storage/wal/errors.go b/internal/storage/wal/errors.go new file mode 100644 index 0000000..9812bb6 --- /dev/null +++ b/internal/storage/wal/errors.go @@ -0,0 +1,8 @@ +package wal + +import "errors" + +var ( + ErrInvalidCRC = errors.New("corrupt wal record: crc32 mismatch") + ErrTruncated = errors.New("corrupt wal record: truncated payload") +) diff --git a/internal/storage/wal/record.go b/internal/storage/wal/record.go new file mode 100644 index 0000000..e7b00b1 --- /dev/null +++ b/internal/storage/wal/record.go @@ -0,0 +1,72 @@ +package wal + +import ( + "encoding/binary" + "hash/crc32" +) + +const ( + OpcodePut uint8 = 0 + OpcodeDelete uint8 = 1 +) + +type Record struct { + Opcode uint8 + Key []byte + Value []byte +} + +func (record *Record) Marshal() []byte { + metadataAndDataSize := 3 + len(record.Key) + len(record.Value) + totalFrameSizeBytes := 8 + metadataAndDataSize + + frameBuffer := make([]byte, totalFrameSizeBytes) + + binary.LittleEndian.PutUint32(frameBuffer[4:8], uint32(totalFrameSizeBytes)) + frameBuffer[8] = record.Opcode + binary.LittleEndian.PutUint16(frameBuffer[9:11], uint16(len(record.Key))) + + copy(frameBuffer[11:], record.Key) + copy(frameBuffer[11+len(record.Key):], record.Value) + + calculatedChecksum := crc32.ChecksumIEEE(frameBuffer[4:]) + binary.LittleEndian.PutUint32(frameBuffer[0:4], calculatedChecksum) + + return frameBuffer +} + +func UnmarshalRecord(frameData []byte) (*Record, error) { + if len(frameData) < 11 { + return nil, ErrTruncated + } + + storedChecksum := binary.LittleEndian.Uint32(frameData[0:4]) + calculatedChecksum := crc32.ChecksumIEEE(frameData[4:]) + + if storedChecksum != calculatedChecksum { + return nil, ErrInvalidCRC + } + + extractedOpcode := frameData[8] + extractedKeyLength := binary.LittleEndian.Uint16(frameData[9:11]) + + if len(frameData) < int(11+extractedKeyLength) { + return nil, ErrTruncated + } + + extractedKey := make([]byte, extractedKeyLength) + copy(extractedKey, frameData[11:11+extractedKeyLength]) + + extractedValueLength := len(frameData) - int(11+extractedKeyLength) + var extractedValue []byte + if extractedValueLength > 0 { + extractedValue = make([]byte, extractedValueLength) + copy(extractedValue, frameData[11+extractedKeyLength:]) + } + + return &Record{ + Opcode: extractedOpcode, + Key: extractedKey, + Value: extractedValue, + }, nil +} From 6bd3d054aa4e5ba44df59106db0d316b72c3dc85 Mon Sep 17 00:00:00 2001 From: Souvik606 Date: Fri, 5 Jun 2026 15:53:15 +0530 Subject: [PATCH 02/18] Implement the Batch Group Write Worker logic --- internal/storage/wal/writer.go | 105 +++++++++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 internal/storage/wal/writer.go diff --git a/internal/storage/wal/writer.go b/internal/storage/wal/writer.go new file mode 100644 index 0000000..441f58b --- /dev/null +++ b/internal/storage/wal/writer.go @@ -0,0 +1,105 @@ +package wal + +import ( + "log/slog" + "os" + "sync" +) + +type commitTicket struct { + frameData []byte + resultChan chan error +} + +type LogWriter struct { + activeFile *os.File + ingestionChannel chan *commitTicket + shutdownSignal chan struct{} + workerWaitGroup sync.WaitGroup +} + +func NewLogWriter(filePath string) (*LogWriter, error) { + file, err := os.OpenFile(filePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return nil, err + } + + writer := &LogWriter{ + activeFile: file, + ingestionChannel: make(chan *commitTicket, 10000), + shutdownSignal: make(chan struct{}), + } + + writer.workerWaitGroup.Add(1) + go writer.batchWorker() + + return writer, nil +} + +func (writer *LogWriter) Append(record *Record) error { + frame := record.Marshal() + + ticket := &commitTicket{ + frameData: frame, + resultChan: make(chan error, 1), + } + + slog.Debug("network thread: dropping record into ingestion channel", "frame_size", len(frame)) + + select { + case writer.ingestionChannel <- ticket: + return <-ticket.resultChan + case <-writer.shutdownSignal: + return os.ErrClosed + } +} + +func (writer *LogWriter) batchWorker() { + defer writer.workerWaitGroup.Done() + + var commitBatch []*commitTicket + var writeBuffer []byte + + for { + select { + case <-writer.shutdownSignal: + return + + case firstTicket := <-writer.ingestionChannel: + commitBatch = append(commitBatch[:0], firstTicket) + writeBuffer = append(writeBuffer[:0], firstTicket.frameData...) + + pendingWrites := len(writer.ingestionChannel) + for i := 0; i < pendingWrites; i++ { + ticket := <-writer.ingestionChannel + commitBatch = append(commitBatch, ticket) + writeBuffer = append(writeBuffer, ticket.frameData...) + } + + slog.Debug("batch worker: executing group commit", + "batch_size", len(commitBatch), + "total_bytes", len(writeBuffer)) + + _, err := writer.activeFile.Write(writeBuffer) + if err == nil { + err = writer.activeFile.Sync() + } + + if err != nil { + slog.Debug("batch worker: fsync failed", "error", err) + } else { + slog.Debug("batch worker: fsync successful") + } + + for _, ticket := range commitBatch { + ticket.resultChan <- err + } + } + } +} + +func (writer *LogWriter) Close() error { + close(writer.shutdownSignal) + writer.workerWaitGroup.Wait() + return writer.activeFile.Close() +} From e8243954dfde4ee73198bb44c8c186ba8878adb2 Mon Sep 17 00:00:00 2001 From: Souvik606 Date: Fri, 5 Jun 2026 16:33:25 +0530 Subject: [PATCH 03/18] Implement segment rotation logic --- internal/storage/wal/writer.go | 56 ++++++++++++++++++++++++++++++---- 1 file changed, 50 insertions(+), 6 deletions(-) diff --git a/internal/storage/wal/writer.go b/internal/storage/wal/writer.go index 441f58b..c44afc3 100644 --- a/internal/storage/wal/writer.go +++ b/internal/storage/wal/writer.go @@ -1,41 +1,75 @@ package wal import ( + "fmt" "log/slog" "os" + "path/filepath" "sync" ) +const MaxSegmentSizeBytes int64 = 32 * 1024 * 1024 + type commitTicket struct { frameData []byte resultChan chan error } type LogWriter struct { + directory string activeFile *os.File + currentSegmentID int + currentSizeBytes int64 + ingestionChannel chan *commitTicket shutdownSignal chan struct{} workerWaitGroup sync.WaitGroup } -func NewLogWriter(filePath string) (*LogWriter, error) { - file, err := os.OpenFile(filePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) - if err != nil { - return nil, err +func NewLogWriter(directory string, nextSegmentID int) (*LogWriter, error) { + if err := os.MkdirAll(directory, 0755); err != nil { + return nil, fmt.Errorf("failed to initialize WAL directory structure at %s: %w", directory, err) } writer := &LogWriter{ - activeFile: file, + directory: directory, + currentSegmentID: nextSegmentID, ingestionChannel: make(chan *commitTicket, 10000), shutdownSignal: make(chan struct{}), } + if err := writer.rotateActiveFile(); err != nil { + return nil, err + } + writer.workerWaitGroup.Add(1) go writer.batchWorker() return writer, nil } +func (writer *LogWriter) rotateActiveFile() error { + if writer.activeFile != nil { + if err := writer.activeFile.Sync(); err != nil { + return fmt.Errorf("failed to sync WAL segment %d during rotation: %w", writer.currentSegmentID, err) + } + if err := writer.activeFile.Close(); err != nil { + return fmt.Errorf("failed to close WAL segment %d during rotation: %w", writer.currentSegmentID, err) + } + writer.currentSegmentID++ + } + + segmentPath := filepath.Join(writer.directory, fmt.Sprintf("%06d.wal", writer.currentSegmentID)) + file, err := os.OpenFile(segmentPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return fmt.Errorf("failed to open new WAL segment %s: %w", segmentPath, err) + } + + writer.activeFile = file + writer.currentSizeBytes = 0 + return nil +} + func (writer *LogWriter) Append(record *Record) error { frame := record.Marshal() @@ -50,7 +84,7 @@ func (writer *LogWriter) Append(record *Record) error { case writer.ingestionChannel <- ticket: return <-ticket.resultChan case <-writer.shutdownSignal: - return os.ErrClosed + return fmt.Errorf("database engine is currently shutting down, write rejected") } } @@ -76,6 +110,15 @@ func (writer *LogWriter) batchWorker() { writeBuffer = append(writeBuffer, ticket.frameData...) } + if writer.currentSizeBytes+int64(len(writeBuffer)) > MaxSegmentSizeBytes { + if err := writer.rotateActiveFile(); err != nil { + for _, ticket := range commitBatch { + ticket.resultChan <- err + } + continue + } + } + slog.Debug("batch worker: executing group commit", "batch_size", len(commitBatch), "total_bytes", len(writeBuffer)) @@ -83,6 +126,7 @@ func (writer *LogWriter) batchWorker() { _, err := writer.activeFile.Write(writeBuffer) if err == nil { err = writer.activeFile.Sync() + writer.currentSizeBytes += int64(len(writeBuffer)) } if err != nil { From bd89933db6c1deeac8f39230f99785d990756cee Mon Sep 17 00:00:00 2001 From: Souvik606 Date: Fri, 5 Jun 2026 18:33:13 +0530 Subject: [PATCH 04/18] Implement the replay logic to rebuild the memtable on server crash --- internal/storage/wal/recovery.go | 137 +++++++++++++++++++++++++++++++ internal/storage/wal/writer.go | 6 +- 2 files changed, 142 insertions(+), 1 deletion(-) create mode 100644 internal/storage/wal/recovery.go diff --git a/internal/storage/wal/recovery.go b/internal/storage/wal/recovery.go new file mode 100644 index 0000000..e464fc5 --- /dev/null +++ b/internal/storage/wal/recovery.go @@ -0,0 +1,137 @@ +package wal + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "log/slog" + "os" + "path/filepath" + "sort" +) + +type MemTable interface { + Put(key, value []byte) error + Delete(key []byte) error +} + +func Replay(directory string, engine MemTable) (int, error) { + slog.Debug("starting WAL recovery sequence", "directory", directory) + + entries, err := os.ReadDir(directory) + if err != nil { + if os.IsNotExist(err) { + slog.Debug("WAL directory does not exist, starting fresh", "directory", directory) + return 1, nil + } + return 0, fmt.Errorf("failed to scan WAL directory during recovery boot: %w", err) + } + + var walFiles []string + for _, entry := range entries { + if !entry.IsDir() && filepath.Ext(entry.Name()) == ".wal" { + walFiles = append(walFiles, entry.Name()) + } + } + + slog.Debug("found WAL segments for replay", "count", len(walFiles)) + sort.Strings(walFiles) + + highestSegmentID := 0 + + for _, fileName := range walFiles { + var segmentID int + fmt.Sscanf(fileName, "%d.wal", &segmentID) + if segmentID > highestSegmentID { + highestSegmentID = segmentID + } + + filePath := filepath.Join(directory, fileName) + slog.Debug("replaying WAL segment", "segment_id", segmentID, "file", fileName) + + if err := replayFile(filePath, engine); err != nil { + return 0, fmt.Errorf("critical failure while replaying segment %s: %w", fileName, err) + } + } + + if highestSegmentID == 0 { + highestSegmentID = 1 + } + + slog.Debug("WAL recovery complete", "highest_segment_id", highestSegmentID) + return highestSegmentID, nil +} + +func replayFile(filePath string, engine MemTable) error { + file, err := os.OpenFile(filePath, os.O_RDWR, 0644) + if err != nil { + return fmt.Errorf("unable to open WAL segment for reading: %w", err) + } + defer file.Close() + + var validBytes int64 = 0 + var recordsRecovered int + headerBuffer := make([]byte, 8) + + for { + _, err := io.ReadFull(file, headerBuffer) + if err != nil { + if errors.Is(err, io.EOF) { + slog.Debug("reached clean EOF for WAL segment", + "file", filepath.Base(filePath), + "records_recovered", recordsRecovered) + break + } + if errors.Is(err, io.ErrUnexpectedEOF) { + slog.Debug("unexpected EOF in header, truncating segment", + "file", filepath.Base(filePath), + "valid_bytes", validBytes) + return file.Truncate(validBytes) + } + return fmt.Errorf("unexpected disk error reading frame header: %w", err) + } + + totalFrameSizeBytes := binary.LittleEndian.Uint32(headerBuffer[4:8]) + payloadSizeBytes := totalFrameSizeBytes - 8 + + payloadBuffer := make([]byte, payloadSizeBytes) + _, err = io.ReadFull(file, payloadBuffer) + if err != nil { + slog.Debug("unexpected EOF in payload, truncating segment", + "file", filepath.Base(filePath), + "valid_bytes", validBytes) + return file.Truncate(validBytes) + } + + fullFrame := append(headerBuffer, payloadBuffer...) + + record, err := UnmarshalRecord(fullFrame) + if err != nil { + if errors.Is(err, ErrInvalidCRC) || errors.Is(err, ErrTruncated) { + slog.Debug("corrupted frame detected, truncating segment", + "file", filepath.Base(filePath), + "valid_bytes", validBytes, + "error", err) + return file.Truncate(validBytes) + } + return fmt.Errorf("failed to decode valid frame payload: %w", err) + } + + switch record.Opcode { + case OpcodePut: + if putErr := engine.Put(record.Key, record.Value); putErr != nil { + return fmt.Errorf("memtable rejected recovered put operation: %w", putErr) + } + case OpcodeDelete: + if delErr := engine.Delete(record.Key); delErr != nil { + return fmt.Errorf("memtable rejected recovered delete operation: %w", delErr) + } + } + + validBytes += int64(totalFrameSizeBytes) + recordsRecovered++ + } + + return nil +} diff --git a/internal/storage/wal/writer.go b/internal/storage/wal/writer.go index c44afc3..3f86426 100644 --- a/internal/storage/wal/writer.go +++ b/internal/storage/wal/writer.go @@ -145,5 +145,9 @@ func (writer *LogWriter) batchWorker() { func (writer *LogWriter) Close() error { close(writer.shutdownSignal) writer.workerWaitGroup.Wait() - return writer.activeFile.Close() + if writer.activeFile != nil { + _ = writer.activeFile.Sync() + return writer.activeFile.Close() + } + return nil } From 91c9f7db21b40d387e39770d524b453c0b2199ec Mon Sep 17 00:00:00 2001 From: Souvik606 Date: Fri, 5 Jun 2026 19:19:20 +0530 Subject: [PATCH 05/18] Add unit test cases for WAL phase and added concurrency tests for skiplist --- internal/storage/memtable/skiplist_test.go | 80 ++++ internal/storage/wal/errors.go | 1 + internal/storage/wal/record.go | 6 +- internal/storage/wal/record_test.go | 259 ++++++++++++ internal/storage/wal/recovery_test.go | 400 ++++++++++++++++++ internal/storage/wal/writer.go | 22 +- internal/storage/wal/writer_test.go | 453 +++++++++++++++++++++ 7 files changed, 1211 insertions(+), 10 deletions(-) create mode 100644 internal/storage/wal/record_test.go create mode 100644 internal/storage/wal/recovery_test.go create mode 100644 internal/storage/wal/writer_test.go diff --git a/internal/storage/memtable/skiplist_test.go b/internal/storage/memtable/skiplist_test.go index 20d24b0..ef79d65 100644 --- a/internal/storage/memtable/skiplist_test.go +++ b/internal/storage/memtable/skiplist_test.go @@ -2,6 +2,8 @@ package memtable import ( "bytes" + "fmt" + "sync" "testing" ) @@ -175,6 +177,84 @@ func TestSkipList_EmptyAndNil(t *testing.T) { } } +// TestSkipList_StrictConcurrency verifies that a single writer goroutine and +// multiple concurrent reader goroutines operating on the same key never observe +// a corrupted or malformed value, confirming read-write lock correctness. +func TestSkipList_StrictConcurrency(t *testing.T) { + skipList := NewSkipList(100000, 12) + var waitGroup sync.WaitGroup + key := []byte("shared-key") + + waitGroup.Add(1) + go func() { + defer waitGroup.Done() + for i := 0; i < 1000; i++ { + val := []byte(fmt.Sprintf("val-%d", i)) + _ = skipList.Put(key, val) + } + }() + + for r := 0; r < 5; r++ { + waitGroup.Add(1) + go func() { + defer waitGroup.Done() + for i := 0; i < 1000; i++ { + val, err := skipList.Get(key) + if err != nil && err != ErrKeyNotFound { + t.Errorf("unexpected error on Get: %v", err) + } + if err == nil { + if !bytes.HasPrefix(val, []byte("val-")) { + t.Errorf("corrupted value: %s", val) + } + } + } + }() + } + + waitGroup.Wait() +} + +// TestSkipList_Concurrency is a broad stress test that runs concurrent Puts, +// Deletes, and Iterator traversals across disjoint key ranges simultaneously, +// verifying that no deadlock, panic, or data corruption occurs under contention. +func TestSkipList_Concurrency(t *testing.T) { + skipList := NewSkipList(100000, 12) + var waitGroup sync.WaitGroup + + for i := 0; i < 100; i++ { + waitGroup.Add(1) + go func(id int) { + defer waitGroup.Done() + key := []byte(fmt.Sprintf("key-%03d", id)) + val := []byte(fmt.Sprintf("val-%03d", id)) + _ = skipList.Put(key, val) + }(i) + } + + for i := 100; i < 200; i++ { + waitGroup.Add(1) + go func(id int) { + defer waitGroup.Done() + key := []byte(fmt.Sprintf("key-%03d", id)) + _ = skipList.Delete(key) + }(i) + } + + for i := 0; i < 20; i++ { + waitGroup.Add(1) + go func() { + defer waitGroup.Done() + iterator := skipList.NewIterator() + for iterator.Valid() { + _, _, _ = iterator.Next() + } + }() + } + + waitGroup.Wait() +} + // TestSkipList_SortedOrder verifies that the iterator always returns keys in // ascending lexicographic order, regardless of insertion order. func TestSkipList_SortedOrder(t *testing.T) { diff --git a/internal/storage/wal/errors.go b/internal/storage/wal/errors.go index 9812bb6..dff1f31 100644 --- a/internal/storage/wal/errors.go +++ b/internal/storage/wal/errors.go @@ -5,4 +5,5 @@ import "errors" var ( ErrInvalidCRC = errors.New("corrupt wal record: crc32 mismatch") ErrTruncated = errors.New("corrupt wal record: truncated payload") + ErrEmptyKey = errors.New("wal record rejected: key must not be empty") ) diff --git a/internal/storage/wal/record.go b/internal/storage/wal/record.go index e7b00b1..06a88d6 100644 --- a/internal/storage/wal/record.go +++ b/internal/storage/wal/record.go @@ -50,14 +50,14 @@ func UnmarshalRecord(frameData []byte) (*Record, error) { extractedOpcode := frameData[8] extractedKeyLength := binary.LittleEndian.Uint16(frameData[9:11]) - if len(frameData) < int(11+extractedKeyLength) { + if len(frameData) < 11+int(extractedKeyLength) { return nil, ErrTruncated } extractedKey := make([]byte, extractedKeyLength) - copy(extractedKey, frameData[11:11+extractedKeyLength]) + copy(extractedKey, frameData[11:11+int(extractedKeyLength)]) - extractedValueLength := len(frameData) - int(11+extractedKeyLength) + extractedValueLength := len(frameData) - (11 + int(extractedKeyLength)) var extractedValue []byte if extractedValueLength > 0 { extractedValue = make([]byte, extractedValueLength) diff --git a/internal/storage/wal/record_test.go b/internal/storage/wal/record_test.go new file mode 100644 index 0000000..1b924a2 --- /dev/null +++ b/internal/storage/wal/record_test.go @@ -0,0 +1,259 @@ +package wal + +import ( + "bytes" + "encoding/binary" + "hash/crc32" + "testing" +) + +func buildValidFrame(r *Record) []byte { + return r.Marshal() +} + +func TestMarshal_FrameLayout(t *testing.T) { + r := &Record{Opcode: OpcodePut, Key: []byte("hello"), Value: []byte("world")} + frame := r.Marshal() + + wantSize := 8 + 3 + len(r.Key) + len(r.Value) + if len(frame) != wantSize { + t.Fatalf("frame length = %d, want %d", len(frame), wantSize) + } + + storedSize := binary.LittleEndian.Uint32(frame[4:8]) + if int(storedSize) != wantSize { + t.Errorf("stored size field = %d, want %d", storedSize, wantSize) + } + + if frame[8] != OpcodePut { + t.Errorf("opcode byte = %d, want %d", frame[8], OpcodePut) + } + + storedKeyLen := binary.LittleEndian.Uint16(frame[9:11]) + if int(storedKeyLen) != len(r.Key) { + t.Errorf("stored key length = %d, want %d", storedKeyLen, len(r.Key)) + } + + if !bytes.Equal(frame[11:11+len(r.Key)], r.Key) { + t.Error("key bytes mismatch in frame") + } + + if !bytes.Equal(frame[11+len(r.Key):], r.Value) { + t.Error("value bytes mismatch in frame") + } +} + +func TestMarshal_CRCCoversPayload(t *testing.T) { + r := &Record{Opcode: OpcodeDelete, Key: []byte("k"), Value: nil} + frame := r.Marshal() + + storedCRC := binary.LittleEndian.Uint32(frame[0:4]) + if storedCRC != crc32.ChecksumIEEE(frame[4:]) { + t.Errorf("CRC mismatch: stored=%d calculated=%d", storedCRC, crc32.ChecksumIEEE(frame[4:])) + } +} + +func TestMarshal_OpcodeDelete(t *testing.T) { + r := &Record{Opcode: OpcodeDelete, Key: []byte("mykey"), Value: nil} + frame := r.Marshal() + if frame[8] != OpcodeDelete { + t.Errorf("opcode = %d, want OpcodeDelete (%d)", frame[8], OpcodeDelete) + } +} + +func TestMarshal_ZeroLengthKey(t *testing.T) { + cases := []struct { + name string + key []byte + }{ + {"empty", []byte{}}, + {"nil", nil}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + frame := (&Record{Opcode: OpcodePut, Key: tc.key, Value: []byte("v")}).Marshal() + if keyLen := binary.LittleEndian.Uint16(frame[9:11]); keyLen != 0 { + t.Errorf("key length = %d, want 0", keyLen) + } + }) + } +} + +func TestMarshal_ZeroLengthValue(t *testing.T) { + cases := []struct { + name string + value []byte + }{ + {"empty", []byte{}}, + {"nil", nil}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + r := &Record{Opcode: OpcodePut, Key: []byte("key"), Value: tc.value} + frame := r.Marshal() + wantSize := 8 + 3 + len(r.Key) + if len(frame) != wantSize { + t.Errorf("frame length = %d, want %d", len(frame), wantSize) + } + }) + } +} + +func TestMarshal_LargePayload(t *testing.T) { + key := bytes.Repeat([]byte("k"), 1024) + value := bytes.Repeat([]byte("v"), 4096) + r := &Record{Opcode: OpcodePut, Key: key, Value: value} + frame := r.Marshal() + wantSize := 8 + 3 + len(key) + len(value) + if len(frame) != wantSize { + t.Errorf("frame size = %d, want %d", len(frame), wantSize) + } + storedCRC := binary.LittleEndian.Uint32(frame[0:4]) + if storedCRC != crc32.ChecksumIEEE(frame[4:]) { + t.Error("CRC invalid for large payload") + } +} + +func TestMarshal_BinaryKeyAndValue(t *testing.T) { + key := []byte{0x00, 0xFF, 0x7F, 0x80} + value := []byte{0x01, 0x02, 0x03} + r := &Record{Opcode: OpcodePut, Key: key, Value: value} + frame := r.Marshal() + + recovered, err := UnmarshalRecord(frame) + if err != nil { + t.Fatalf("UnmarshalRecord failed: %v", err) + } + if !bytes.Equal(recovered.Key, key) { + t.Errorf("key mismatch: got %v, want %v", recovered.Key, key) + } + if !bytes.Equal(recovered.Value, value) { + t.Errorf("value mismatch: got %v, want %v", recovered.Value, value) + } +} + +func TestMarshal_Idempotent(t *testing.T) { + r := &Record{Opcode: OpcodePut, Key: []byte("idempotent"), Value: []byte("yes")} + if !bytes.Equal(r.Marshal(), r.Marshal()) { + t.Error("Marshal is not idempotent for the same record") + } +} + +func TestUnmarshal_RoundTrip_Put(t *testing.T) { + original := &Record{Opcode: OpcodePut, Key: []byte("name"), Value: []byte("penguin")} + frame := buildValidFrame(original) + + recovered, err := UnmarshalRecord(frame) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if recovered.Opcode != original.Opcode { + t.Errorf("opcode: got %d, want %d", recovered.Opcode, original.Opcode) + } + if !bytes.Equal(recovered.Key, original.Key) { + t.Errorf("key: got %q, want %q", recovered.Key, original.Key) + } + if !bytes.Equal(recovered.Value, original.Value) { + t.Errorf("value: got %q, want %q", recovered.Value, original.Value) + } +} + +func TestUnmarshal_RoundTrip_Delete(t *testing.T) { + original := &Record{Opcode: OpcodeDelete, Key: []byte("to-remove"), Value: nil} + frame := buildValidFrame(original) + + recovered, err := UnmarshalRecord(frame) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if recovered.Opcode != OpcodeDelete { + t.Errorf("opcode: got %d, want OpcodeDelete", recovered.Opcode) + } + if !bytes.Equal(recovered.Key, original.Key) { + t.Errorf("key mismatch") + } +} + +func TestUnmarshal_TruncatedFrame_TooShort(t *testing.T) { + cases := [][]byte{ + {}, + {0x01}, + {0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A}, + } + for _, data := range cases { + _, err := UnmarshalRecord(data) + if err != ErrTruncated { + t.Errorf("input len=%d: expected ErrTruncated, got %v", len(data), err) + } + } +} + +func TestUnmarshal_MinimalFrame_EmptyKeyNoValue(t *testing.T) { + r := &Record{Opcode: OpcodePut, Key: []byte{}, Value: nil} + frame := r.Marshal() + recovered, err := UnmarshalRecord(frame) + if err != nil { + t.Fatalf("unexpected error for minimal frame: %v", err) + } + if len(recovered.Key) != 0 { + t.Errorf("expected empty key, got %v", recovered.Key) + } +} + +func TestUnmarshal_CRCCorruption(t *testing.T) { + cases := []struct { + name string + corruptFn func([]byte) + }{ + {"payload byte", func(f []byte) { f[12] ^= 0x01 }}, + {"crc field", func(f []byte) { f[0] ^= 0xFF }}, + {"size field", func(f []byte) { binary.LittleEndian.PutUint32(f[4:8], 9999) }}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + r := &Record{Opcode: OpcodePut, Key: []byte("key"), Value: []byte("val")} + frame := r.Marshal() + tc.corruptFn(frame) + _, err := UnmarshalRecord(frame) + if err != ErrInvalidCRC { + t.Errorf("expected ErrInvalidCRC, got %v", err) + } + }) + } +} + +func TestUnmarshal_KeyLengthExceedsFrame(t *testing.T) { + r := &Record{Opcode: OpcodePut, Key: []byte("ab"), Value: []byte("v")} + frame := r.Marshal() + + binary.LittleEndian.PutUint16(frame[9:11], 65535) + newCRC := crc32.ChecksumIEEE(frame[4:]) + binary.LittleEndian.PutUint32(frame[0:4], newCRC) + + _, err := UnmarshalRecord(frame) + if err != ErrTruncated { + t.Errorf("expected ErrTruncated when keyLen > frame size, got %v", err) + } +} + +func TestUnmarshal_NilValueForDeleteRecord(t *testing.T) { + r := &Record{Opcode: OpcodeDelete, Key: []byte("mykey"), Value: nil} + frame := r.Marshal() + + recovered, err := UnmarshalRecord(frame) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if recovered.Value != nil { + t.Errorf("expected nil Value for delete record, got %v", recovered.Value) + } +} + +func TestOpcodeValues(t *testing.T) { + if OpcodePut != 0 { + t.Errorf("OpcodePut = %d, want 0", OpcodePut) + } + if OpcodeDelete != 1 { + t.Errorf("OpcodeDelete = %d, want 1", OpcodeDelete) + } +} diff --git a/internal/storage/wal/recovery_test.go b/internal/storage/wal/recovery_test.go new file mode 100644 index 0000000..96f17c4 --- /dev/null +++ b/internal/storage/wal/recovery_test.go @@ -0,0 +1,400 @@ +package wal + +import ( + "encoding/binary" + "fmt" + "hash/crc32" + "os" + "path/filepath" + "testing" +) + +type mockMemTable struct { + puts map[string][]byte + deletes []string + putErr error + delErr error +} + +func newMockMemTable() *mockMemTable { + return &mockMemTable{puts: make(map[string][]byte)} +} + +func (m *mockMemTable) Put(key, value []byte) error { + if m.putErr != nil { + return m.putErr + } + m.puts[string(key)] = value + return nil +} + +func (m *mockMemTable) Delete(key []byte) error { + if m.delErr != nil { + return m.delErr + } + m.deletes = append(m.deletes, string(key)) + return nil +} + +func writeRecordsToFile(t *testing.T, path string, records []*Record) { + t.Helper() + f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err != nil { + t.Fatalf("writeRecordsToFile open: %v", err) + } + defer f.Close() + for _, r := range records { + if _, err := f.Write(r.Marshal()); err != nil { + t.Fatalf("writeRecordsToFile write: %v", err) + } + } +} + +func segmentPath(dir string, id int) string { + return filepath.Join(dir, fmt.Sprintf("%06d.wal", id)) +} + +func TestReplay_NonExistentDirectory_ReturnsFreshSegmentID(t *testing.T) { + dir := filepath.Join(t.TempDir(), "no-such-wal") + mem := newMockMemTable() + + nextID, err := Replay(dir, mem) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if nextID != 1 { + t.Errorf("nextID = %d, want 1", nextID) + } + if len(mem.puts) != 0 || len(mem.deletes) != 0 { + t.Error("memtable should be empty for a fresh start") + } +} + +func TestReplay_EmptyDirectory_ReturnsFreshSegmentID(t *testing.T) { + dir := t.TempDir() + mem := newMockMemTable() + + nextID, err := Replay(dir, mem) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if nextID != 1 { + t.Errorf("nextID = %d, want 1", nextID) + } +} + +func TestReplay_NonWALFilesIgnored(t *testing.T) { + dir := t.TempDir() + for _, name := range []string{"data.sst", "manifest", "000001.log", "LOCK"} { + if err := os.WriteFile(filepath.Join(dir, name), []byte("data"), 0644); err != nil { + t.Fatal(err) + } + } + mem := newMockMemTable() + nextID, err := Replay(dir, mem) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if nextID != 1 { + t.Errorf("nextID = %d, want 1", nextID) + } +} + +func TestReplay_SingleSegment_AllPuts(t *testing.T) { + dir := t.TempDir() + records := []*Record{ + {Opcode: OpcodePut, Key: []byte("a"), Value: []byte("1")}, + {Opcode: OpcodePut, Key: []byte("b"), Value: []byte("2")}, + {Opcode: OpcodePut, Key: []byte("c"), Value: []byte("3")}, + } + writeRecordsToFile(t, segmentPath(dir, 1), records) + + mem := newMockMemTable() + nextID, err := Replay(dir, mem) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if nextID != 1 { + t.Errorf("nextID = %d, want 1", nextID) + } + for _, r := range records { + if string(mem.puts[string(r.Key)]) != string(r.Value) { + t.Errorf("key %q: got value %q, want %q", r.Key, mem.puts[string(r.Key)], r.Value) + } + } +} + +func TestReplay_SingleSegment_Deletes(t *testing.T) { + dir := t.TempDir() + writeRecordsToFile(t, segmentPath(dir, 1), []*Record{ + {Opcode: OpcodePut, Key: []byte("x"), Value: []byte("val")}, + {Opcode: OpcodeDelete, Key: []byte("x")}, + }) + + mem := newMockMemTable() + if _, err := Replay(dir, mem); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(mem.deletes) != 1 || mem.deletes[0] != "x" { + t.Errorf("expected one delete for 'x', got %v", mem.deletes) + } +} + +func TestReplay_MultipleSegments_ReplayedInOrder(t *testing.T) { + dir := t.TempDir() + + writeRecordsToFile(t, segmentPath(dir, 3), []*Record{ + {Opcode: OpcodePut, Key: []byte("k"), Value: []byte("seg3")}, + }) + writeRecordsToFile(t, segmentPath(dir, 1), []*Record{ + {Opcode: OpcodePut, Key: []byte("k"), Value: []byte("seg1")}, + }) + writeRecordsToFile(t, segmentPath(dir, 2), []*Record{ + {Opcode: OpcodePut, Key: []byte("k"), Value: []byte("seg2")}, + }) + + mem := newMockMemTable() + nextID, err := Replay(dir, mem) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(mem.puts["k"]) != "seg3" { + t.Errorf("key 'k' value = %q, want 'seg3'", mem.puts["k"]) + } + if nextID != 3 { + t.Errorf("nextID = %d, want 3", nextID) + } +} + +func TestReplay_ReturnsHighestSegmentID(t *testing.T) { + dir := t.TempDir() + for _, id := range []int{1, 5, 10} { + writeRecordsToFile(t, segmentPath(dir, id), []*Record{ + {Opcode: OpcodePut, Key: []byte(fmt.Sprintf("k%d", id)), Value: []byte("v")}, + }) + } + mem := newMockMemTable() + nextID, err := Replay(dir, mem) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if nextID != 10 { + t.Errorf("nextID = %d, want 10", nextID) + } +} + +func TestReplay_UnknownOpcode_Ignored(t *testing.T) { + dir := t.TempDir() + writeRecordsToFile(t, segmentPath(dir, 1), []*Record{ + {Opcode: 99, Key: []byte("k"), Value: []byte("v")}, + }) + + mem := newMockMemTable() + if _, err := Replay(dir, mem); err != nil { + t.Errorf("unexpected error for unknown opcode: %v", err) + } + if _, ok := mem.puts["k"]; ok { + t.Error("unknown opcode record should not be applied to memtable") + } +} + +func TestReplay_CorruptedCRC_TruncatesFile(t *testing.T) { + dir := t.TempDir() + path := segmentPath(dir, 1) + + good := &Record{Opcode: OpcodePut, Key: []byte("good"), Value: []byte("val")} + validBytes := good.Marshal() + writeRecordsToFile(t, path, []*Record{good}) + + badFrame := (&Record{Opcode: OpcodePut, Key: []byte("bad"), Value: []byte("x")}).Marshal() + badFrame[12] ^= 0xFF + f, _ := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, 0644) + f.Write(badFrame) + f.Close() + + mem := newMockMemTable() + if _, err := Replay(dir, mem); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if _, ok := mem.puts["good"]; !ok { + t.Error("good record was not applied to memtable") + } + if _, ok := mem.puts["bad"]; ok { + t.Error("corrupt record was applied to memtable") + } + + info, err := os.Stat(path) + if err != nil { + t.Fatalf("stat: %v", err) + } + if info.Size() != int64(len(validBytes)) { + t.Errorf("file size after truncate = %d, want %d", info.Size(), len(validBytes)) + } +} + +func TestReplay_TruncatedHeader_TruncatesFile(t *testing.T) { + dir := t.TempDir() + path := segmentPath(dir, 1) + + writeRecordsToFile(t, path, []*Record{ + {Opcode: OpcodePut, Key: []byte("ok"), Value: []byte("v")}, + }) + + f, _ := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, 0644) + f.Write([]byte{0xDE, 0xAD, 0xBE, 0xEF, 0xFF}) + f.Close() + + mem := newMockMemTable() + if _, err := Replay(dir, mem); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, ok := mem.puts["ok"]; !ok { + t.Error("valid record was not applied") + } +} + +func TestReplay_TruncatedPayload_TruncatesFile(t *testing.T) { + dir := t.TempDir() + path := segmentPath(dir, 1) + + writeRecordsToFile(t, path, []*Record{ + {Opcode: OpcodePut, Key: []byte("safe"), Value: []byte("data")}, + }) + + fakeKey := []byte("payload-cut") + fakeSizeBytes := uint32(8 + 3 + len(fakeKey) + 100) + hdr := make([]byte, 8) + binary.LittleEndian.PutUint32(hdr[4:8], fakeSizeBytes) + binary.LittleEndian.PutUint32(hdr[0:4], crc32.ChecksumIEEE(hdr[4:])) + + f, _ := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, 0644) + f.Write(hdr) + f.Write(fakeKey[:3]) + f.Close() + + mem := newMockMemTable() + if _, err := Replay(dir, mem); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, ok := mem.puts["safe"]; !ok { + t.Error("valid record before truncation point was not applied") + } +} + +func TestReplay_EmptySegmentFile_NoError(t *testing.T) { + dir := t.TempDir() + os.WriteFile(segmentPath(dir, 1), []byte{}, 0644) + + mem := newMockMemTable() + nextID, err := Replay(dir, mem) + if err != nil { + t.Fatalf("unexpected error for empty WAL file: %v", err) + } + if nextID != 1 { + t.Errorf("nextID = %d, want 1", nextID) + } +} + +func TestReplay_MemTablePutError_PropagatesError(t *testing.T) { + dir := t.TempDir() + writeRecordsToFile(t, segmentPath(dir, 1), []*Record{ + {Opcode: OpcodePut, Key: []byte("k"), Value: []byte("v")}, + }) + + mem := newMockMemTable() + mem.putErr = fmt.Errorf("memtable full") + + if _, err := Replay(dir, mem); err == nil { + t.Fatal("expected error when memtable.Put fails, got nil") + } +} + +func TestReplay_MemTableDeleteError_PropagatesError(t *testing.T) { + dir := t.TempDir() + writeRecordsToFile(t, segmentPath(dir, 1), []*Record{ + {Opcode: OpcodeDelete, Key: []byte("k")}, + }) + + mem := newMockMemTable() + mem.delErr = fmt.Errorf("read-only memtable") + + if _, err := Replay(dir, mem); err == nil { + t.Fatal("expected error when memtable.Delete fails, got nil") + } +} + +func TestReplay_MixedPutsAndDeletes_CorrectOrder(t *testing.T) { + dir := t.TempDir() + mem := newMockMemTable() + + writeRecordsToFile(t, segmentPath(dir, 1), []*Record{ + {Opcode: OpcodePut, Key: []byte("a"), Value: []byte("1")}, + {Opcode: OpcodePut, Key: []byte("b"), Value: []byte("2")}, + {Opcode: OpcodeDelete, Key: []byte("a")}, + {Opcode: OpcodePut, Key: []byte("a"), Value: []byte("3")}, + }) + + if _, err := Replay(dir, mem); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if string(mem.puts["a"]) != "3" { + t.Errorf("key 'a' value = %q, want '3'", mem.puts["a"]) + } + if string(mem.puts["b"]) != "2" { + t.Errorf("key 'b' value = %q, want '2'", mem.puts["b"]) + } + + found := false + for _, d := range mem.deletes { + if d == "a" { + found = true + } + } + if !found { + t.Error("expected a Delete('a') call to memtable") + } +} + +func TestReplay_SegmentsSortedNumerically(t *testing.T) { + dir := t.TempDir() + + for i := 1; i <= 12; i++ { + writeRecordsToFile(t, segmentPath(dir, i), []*Record{ + {Opcode: OpcodePut, Key: []byte("seq"), Value: []byte(fmt.Sprintf("%d", i))}, + }) + } + + mem := newMockMemTable() + nextID, err := Replay(dir, mem) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if nextID != 12 { + t.Errorf("nextID = %d, want 12", nextID) + } + if string(mem.puts["seq"]) != "12" { + t.Errorf("value for 'seq' = %q, want '12'", mem.puts["seq"]) + } +} + +func TestReplay_SubdirectoriesAreIgnored(t *testing.T) { + dir := t.TempDir() + subDir := filepath.Join(dir, "archive.wal") + if err := os.MkdirAll(subDir, 0755); err != nil { + t.Fatal(err) + } + writeRecordsToFile(t, segmentPath(dir, 1), []*Record{ + {Opcode: OpcodePut, Key: []byte("k"), Value: []byte("v")}, + }) + + mem := newMockMemTable() + nextID, err := Replay(dir, mem) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if nextID != 1 { + t.Errorf("nextID = %d, want 1", nextID) + } +} diff --git a/internal/storage/wal/writer.go b/internal/storage/wal/writer.go index 3f86426..8c686ce 100644 --- a/internal/storage/wal/writer.go +++ b/internal/storage/wal/writer.go @@ -24,6 +24,7 @@ type LogWriter struct { ingestionChannel chan *commitTicket shutdownSignal chan struct{} workerWaitGroup sync.WaitGroup + closeOnce sync.Once } func NewLogWriter(directory string, nextSegmentID int) (*LogWriter, error) { @@ -71,6 +72,10 @@ func (writer *LogWriter) rotateActiveFile() error { } func (writer *LogWriter) Append(record *Record) error { + if len(record.Key) == 0 { + return ErrEmptyKey + } + frame := record.Marshal() ticket := &commitTicket{ @@ -143,11 +148,14 @@ func (writer *LogWriter) batchWorker() { } func (writer *LogWriter) Close() error { - close(writer.shutdownSignal) - writer.workerWaitGroup.Wait() - if writer.activeFile != nil { - _ = writer.activeFile.Sync() - return writer.activeFile.Close() - } - return nil + var closeErr error + writer.closeOnce.Do(func() { + close(writer.shutdownSignal) + writer.workerWaitGroup.Wait() + if writer.activeFile != nil { + _ = writer.activeFile.Sync() + closeErr = writer.activeFile.Close() + } + }) + return closeErr } diff --git a/internal/storage/wal/writer_test.go b/internal/storage/wal/writer_test.go new file mode 100644 index 0000000..09a8961 --- /dev/null +++ b/internal/storage/wal/writer_test.go @@ -0,0 +1,453 @@ +package wal + +import ( + "fmt" + "os" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestNewLogWriter_CreatesDirectory(t *testing.T) { + base := t.TempDir() + dir := fmt.Sprintf("%s/nested/wal", base) + + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer w.Close() + + if _, err := os.Stat(dir); os.IsNotExist(err) { + t.Error("WAL directory was not created") + } +} + +func TestNewLogWriter_CreatesFirstSegmentFile(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer w.Close() + + if _, err := os.Stat(segmentPath(dir, 1)); os.IsNotExist(err) { + t.Error("first segment file 000001.wal was not created") + } +} + +func TestNewLogWriter_RespectsStartSegmentID(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 7) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer w.Close() + + if _, err := os.Stat(segmentPath(dir, 7)); os.IsNotExist(err) { + t.Error("segment 000007.wal was not created for startSegmentID=7") + } +} + +func TestNewLogWriter_InvalidDirectory_ReturnsError(t *testing.T) { + tmpFile, err := os.CreateTemp(t.TempDir(), "not-a-dir") + if err != nil { + t.Fatal(err) + } + tmpFile.Close() + + _, err = NewLogWriter(tmpFile.Name()+"/subdir", 1) + if err == nil { + t.Error("expected error when directory path is inside a file, got nil") + } +} + +func TestAppend_SingleRecord_WrittenToDisk(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + defer w.Close() + + r := &Record{Opcode: OpcodePut, Key: []byte("hello"), Value: []byte("world")} + if err := w.Append(r); err != nil { + t.Fatalf("Append: %v", err) + } + + info, err := os.Stat(segmentPath(dir, 1)) + if err != nil { + t.Fatalf("stat: %v", err) + } + if info.Size() == 0 { + t.Error("segment file is empty after Append") + } +} + +func TestAppend_MultipleRecords_AllWritten(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + + records := []*Record{ + {Opcode: OpcodePut, Key: []byte("k1"), Value: []byte("v1")}, + {Opcode: OpcodePut, Key: []byte("k2"), Value: []byte("v2")}, + {Opcode: OpcodeDelete, Key: []byte("k1")}, + } + for _, r := range records { + if err := w.Append(r); err != nil { + t.Fatalf("Append: %v", err) + } + } + + if err := w.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + + mem := newMockMemTable() + if _, err := Replay(dir, mem); err != nil { + t.Fatalf("Replay: %v", err) + } + if _, ok := mem.puts["k2"]; !ok { + t.Error("k2 not in memtable after replay") + } + found := false + for _, d := range mem.deletes { + if d == "k1" { + found = true + } + } + if !found { + t.Error("delete for k1 not replayed") + } +} + +func TestAppend_RecordRoundtrip_ViaReplay(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + + original := &Record{Opcode: OpcodePut, Key: []byte("penguindb"), Value: []byte("rocks")} + if err := w.Append(original); err != nil { + t.Fatalf("Append: %v", err) + } + w.Close() + + mem := newMockMemTable() + if _, err := Replay(dir, mem); err != nil { + t.Fatalf("Replay: %v", err) + } + if string(mem.puts["penguindb"]) != "rocks" { + t.Errorf("expected value 'rocks', got %q", mem.puts["penguindb"]) + } +} + +func TestAppend_AfterClose_ReturnsError(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + w.Close() + + r := &Record{Opcode: OpcodePut, Key: []byte("k"), Value: []byte("v")} + if err = w.Append(r); err == nil { + t.Error("expected error when Appending after Close, got nil") + } +} + +func TestAppend_EmptyKey_Rejected(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + defer w.Close() + + r := &Record{Opcode: OpcodePut, Key: []byte{}, Value: []byte("v")} + if err = w.Append(r); err != ErrEmptyKey { + t.Errorf("expected ErrEmptyKey for empty key, got %v", err) + } +} + +func TestAppend_NilKey_Rejected(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + defer w.Close() + + r := &Record{Opcode: OpcodePut, Key: nil, Value: []byte("v")} + if err = w.Append(r); err != ErrEmptyKey { + t.Errorf("expected ErrEmptyKey for nil key, got %v", err) + } +} + +func TestAppend_EmptyValue_Allowed(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + defer w.Close() + + r := &Record{Opcode: OpcodePut, Key: []byte("k"), Value: []byte{}} + if err := w.Append(r); err != nil { + t.Errorf("unexpected error appending record with empty value: %v", err) + } +} + +func TestRotation_NewSegmentCreatedAfterSizeExceeded(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + defer w.Close() + + w.currentSizeBytes = MaxSegmentSizeBytes - 1 + + r := &Record{Opcode: OpcodePut, Key: []byte("trigger"), Value: []byte("rotation")} + if err := w.Append(r); err != nil { + t.Fatalf("Append: %v", err) + } + + if _, err := os.Stat(segmentPath(dir, 2)); os.IsNotExist(err) { + t.Error("expected segment 000002.wal after rotation") + } +} + +func TestRotation_OldSegmentSyncedOnRotation(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + defer w.Close() + + r1 := &Record{Opcode: OpcodePut, Key: []byte("before"), Value: []byte("rotation")} + if err := w.Append(r1); err != nil { + t.Fatalf("Append r1: %v", err) + } + + w.currentSizeBytes = MaxSegmentSizeBytes + + r2 := &Record{Opcode: OpcodePut, Key: []byte("after"), Value: []byte("rotation")} + if err := w.Append(r2); err != nil { + t.Fatalf("Append r2: %v", err) + } + + w.Close() + + for _, id := range []int{1, 2} { + info, err := os.Stat(segmentPath(dir, id)) + if err != nil { + t.Fatalf("stat segment %d: %v", id, err) + } + if info.Size() == 0 { + t.Errorf("segment %d is empty, expected data", id) + } + } +} + +func TestRotation_SizeResetAfterRotation(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + defer w.Close() + + w.currentSizeBytes = MaxSegmentSizeBytes + r := &Record{Opcode: OpcodePut, Key: []byte("k"), Value: []byte("v")} + if err := w.Append(r); err != nil { + t.Fatalf("Append: %v", err) + } + + if w.currentSizeBytes >= MaxSegmentSizeBytes { + t.Errorf("currentSizeBytes = %d after rotation, expected < %d", + w.currentSizeBytes, MaxSegmentSizeBytes) + } +} + +func TestAppend_ConcurrentWrites_NoDataRace(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + defer w.Close() + + const numGoroutines = 50 + const recordsPerGoroutine = 20 + + var wg sync.WaitGroup + var errCount int64 + + for g := 0; g < numGoroutines; g++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for i := 0; i < recordsPerGoroutine; i++ { + r := &Record{ + Opcode: OpcodePut, + Key: []byte(fmt.Sprintf("goroutine-%d-key-%d", id, i)), + Value: []byte(fmt.Sprintf("value-%d", i)), + } + if err := w.Append(r); err != nil { + atomic.AddInt64(&errCount, 1) + } + } + }(g) + } + + wg.Wait() + + if errCount > 0 { + t.Errorf("%d Append calls failed under concurrency", errCount) + } +} + +func TestAppend_ConcurrentWrites_AllRecordsRecoverable(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + + const numRecords = 100 + keys := make([]string, numRecords) + for i := 0; i < numRecords; i++ { + keys[i] = fmt.Sprintf("concurrent-key-%04d", i) + } + + var wg sync.WaitGroup + for i := 0; i < numRecords; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + r := &Record{ + Opcode: OpcodePut, + Key: []byte(keys[idx]), + Value: []byte("ok"), + } + _ = w.Append(r) + }(i) + } + wg.Wait() + w.Close() + + mem := newMockMemTable() + if _, err := Replay(dir, mem); err != nil { + t.Fatalf("Replay: %v", err) + } + for _, k := range keys { + if _, ok := mem.puts[k]; !ok { + t.Errorf("key %q missing after concurrent write + replay", k) + } + } +} + +func TestClose_IsIdempotent(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + + if err := w.Close(); err != nil { + t.Errorf("first Close error: %v", err) + } + if err := w.Close(); err != nil { + t.Logf("second Close returned (expected nil): %v", err) + } +} + +func TestClose_BlocksUntilWorkerDone(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + + done := make(chan struct{}) + go func() { + w.Close() + close(done) + }() + + select { + case <-done: + case <-time.After(3 * time.Second): + t.Error("Close did not return within 3 seconds") + } +} + +func TestClose_SyncsDataToDisk(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + + r := &Record{Opcode: OpcodePut, Key: []byte("durable"), Value: []byte("yes")} + if err := w.Append(r); err != nil { + t.Fatalf("Append: %v", err) + } + if err := w.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + + mem := newMockMemTable() + if _, err := Replay(dir, mem); err != nil { + t.Fatalf("Replay: %v", err) + } + if string(mem.puts["durable"]) != "yes" { + t.Error("data written before Close was not recovered") + } +} + +func TestMaxSegmentSizeBytes_Is32MB(t *testing.T) { + expected := int64(32 * 1024 * 1024) + if MaxSegmentSizeBytes != expected { + t.Errorf("MaxSegmentSizeBytes = %d, want %d (32 MiB)", MaxSegmentSizeBytes, expected) + } +} + +func TestBatchWorker_GroupCommit_AllTicketsSignalled(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + defer w.Close() + + const n = 200 + errs := make([]error, n) + var wg sync.WaitGroup + for i := 0; i < n; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + r := &Record{ + Opcode: OpcodePut, + Key: []byte(fmt.Sprintf("batch-key-%d", idx)), + Value: []byte("v"), + } + errs[idx] = w.Append(r) + }(i) + } + wg.Wait() + + for i, e := range errs { + if e != nil { + t.Errorf("Append[%d] returned error: %v", i, e) + } + } +} From 1903a89b5d44cd3700efe3d4ab45a5a912b1098b Mon Sep 17 00:00:00 2001 From: Souvik606 Date: Fri, 5 Jun 2026 19:57:16 +0530 Subject: [PATCH 06/18] Fixed missing package bugg and panic on closing bugs --- internal/storage/memtable/skiplist_test.go | 25 +++++++- internal/storage/wal/recovery.go | 15 ++++- internal/storage/wal/recovery_test.go | 66 ++++++++++++++++++++++ internal/storage/wal/writer.go | 17 +++++- 4 files changed, 116 insertions(+), 7 deletions(-) diff --git a/internal/storage/memtable/skiplist_test.go b/internal/storage/memtable/skiplist_test.go index 9fc5a5c..bbdb9f5 100644 --- a/internal/storage/memtable/skiplist_test.go +++ b/internal/storage/memtable/skiplist_test.go @@ -2,6 +2,7 @@ package memtable import ( "bytes" + "errors" "fmt" "sync" "testing" @@ -228,7 +229,9 @@ func TestSkipList_Concurrency(t *testing.T) { defer waitGroup.Done() key := []byte(fmt.Sprintf("key-%03d", id)) val := []byte(fmt.Sprintf("val-%03d", id)) - _ = skipList.Put(key, val) + if err := skipList.Put(key, val); err != nil { + t.Errorf("put failed for %s: %v", key, err) + } }(i) } @@ -237,7 +240,9 @@ func TestSkipList_Concurrency(t *testing.T) { go func(id int) { defer waitGroup.Done() key := []byte(fmt.Sprintf("key-%03d", id)) - _ = skipList.Delete(key) + if err := skipList.Delete(key); err != nil { + t.Errorf("delete failed for %s: %v", key, err) + } }(i) } @@ -253,6 +258,22 @@ func TestSkipList_Concurrency(t *testing.T) { } waitGroup.Wait() + + for i := 0; i < 100; i++ { + key := []byte(fmt.Sprintf("key-%03d", i)) + expected := []byte(fmt.Sprintf("val-%03d", i)) + got, err := skipList.Get(key) + if err != nil || !bytes.Equal(got, expected) { + t.Fatalf("unexpected state for %s: got (%q, %v), want (%q, nil)", key, got, err, expected) + } + } + + for i := 100; i < 200; i++ { + key := []byte(fmt.Sprintf("key-%03d", i)) + if _, err := skipList.Get(key); !errors.Is(err, ErrKeyNotFound) { + t.Fatalf("expected ErrKeyNotFound for deleted/tombstoned key %s, got %v", key, err) + } + } } // TestSkipList_SortedOrder verifies that the iterator always returns keys in diff --git a/internal/storage/wal/recovery.go b/internal/storage/wal/recovery.go index e464fc5..9e86426 100644 --- a/internal/storage/wal/recovery.go +++ b/internal/storage/wal/recovery.go @@ -63,12 +63,16 @@ func Replay(directory string, engine MemTable) (int, error) { return highestSegmentID, nil } -func replayFile(filePath string, engine MemTable) error { +func replayFile(filePath string, engine MemTable) (err error) { file, err := os.OpenFile(filePath, os.O_RDWR, 0644) if err != nil { return fmt.Errorf("unable to open WAL segment for reading: %w", err) } - defer file.Close() + defer func() { + if closeErr := file.Close(); closeErr != nil && err == nil { + err = fmt.Errorf("failed to close WAL segment %s: %w", filePath, closeErr) + } + }() var validBytes int64 = 0 var recordsRecovered int @@ -93,6 +97,13 @@ func replayFile(filePath string, engine MemTable) error { } totalFrameSizeBytes := binary.LittleEndian.Uint32(headerBuffer[4:8]) + if totalFrameSizeBytes < 8 || totalFrameSizeBytes > 128*1024*1024 { + slog.Debug("invalid frame size in header, truncating segment", + "file", filepath.Base(filePath), + "frame_size", totalFrameSizeBytes, + "valid_bytes", validBytes) + return file.Truncate(validBytes) + } payloadSizeBytes := totalFrameSizeBytes - 8 payloadBuffer := make([]byte, payloadSizeBytes) diff --git a/internal/storage/wal/recovery_test.go b/internal/storage/wal/recovery_test.go index 96f17c4..e516cf6 100644 --- a/internal/storage/wal/recovery_test.go +++ b/internal/storage/wal/recovery_test.go @@ -398,3 +398,69 @@ func TestReplay_SubdirectoriesAreIgnored(t *testing.T) { t.Errorf("nextID = %d, want 1", nextID) } } + +func TestReplay_InvalidFrameSize_TooSmall_TruncatesFile(t *testing.T) { + dir := t.TempDir() + path := segmentPath(dir, 1) + + good := &Record{Opcode: OpcodePut, Key: []byte("good"), Value: []byte("val")} + validBytes := good.Marshal() + writeRecordsToFile(t, path, []*Record{good}) + + hdr := make([]byte, 8) + binary.LittleEndian.PutUint32(hdr[4:8], 7) + binary.LittleEndian.PutUint32(hdr[0:4], crc32.ChecksumIEEE(hdr[4:])) + + f, _ := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, 0644) + f.Write(hdr) + f.Close() + + mem := newMockMemTable() + if _, err := Replay(dir, mem); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, ok := mem.puts["good"]; !ok { + t.Error("good record was not applied") + } + + info, err := os.Stat(path) + if err != nil { + t.Fatalf("stat: %v", err) + } + if info.Size() != int64(len(validBytes)) { + t.Errorf("file size after truncate = %d, want %d", info.Size(), len(validBytes)) + } +} + +func TestReplay_InvalidFrameSize_TooLarge_TruncatesFile(t *testing.T) { + dir := t.TempDir() + path := segmentPath(dir, 1) + + good := &Record{Opcode: OpcodePut, Key: []byte("good"), Value: []byte("val")} + validBytes := good.Marshal() + writeRecordsToFile(t, path, []*Record{good}) + + hdr := make([]byte, 8) + binary.LittleEndian.PutUint32(hdr[4:8], 129*1024*1024) + binary.LittleEndian.PutUint32(hdr[0:4], crc32.ChecksumIEEE(hdr[4:])) + + f, _ := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, 0644) + f.Write(hdr) + f.Close() + + mem := newMockMemTable() + if _, err := Replay(dir, mem); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, ok := mem.puts["good"]; !ok { + t.Error("good record was not applied") + } + + info, err := os.Stat(path) + if err != nil { + t.Fatalf("stat: %v", err) + } + if info.Size() != int64(len(validBytes)) { + t.Errorf("file size after truncate = %d, want %d", info.Size(), len(validBytes)) + } +} diff --git a/internal/storage/wal/writer.go b/internal/storage/wal/writer.go index 8c686ce..a1f6034 100644 --- a/internal/storage/wal/writer.go +++ b/internal/storage/wal/writer.go @@ -87,7 +87,12 @@ func (writer *LogWriter) Append(record *Record) error { select { case writer.ingestionChannel <- ticket: - return <-ticket.resultChan + select { + case err := <-ticket.resultChan: + return err + case <-writer.shutdownSignal: + return fmt.Errorf("database engine is currently shutting down, write rejected") + } case <-writer.shutdownSignal: return fmt.Errorf("database engine is currently shutting down, write rejected") } @@ -131,6 +136,8 @@ func (writer *LogWriter) batchWorker() { _, err := writer.activeFile.Write(writeBuffer) if err == nil { err = writer.activeFile.Sync() + } + if err == nil { writer.currentSizeBytes += int64(len(writeBuffer)) } @@ -153,8 +160,12 @@ func (writer *LogWriter) Close() error { close(writer.shutdownSignal) writer.workerWaitGroup.Wait() if writer.activeFile != nil { - _ = writer.activeFile.Sync() - closeErr = writer.activeFile.Close() + if syncErr := writer.activeFile.Sync(); syncErr != nil { + closeErr = syncErr + } + if err := writer.activeFile.Close(); err != nil && closeErr == nil { + closeErr = err + } } }) return closeErr From eabe20c787401678500c5be89eb6caf66f80ca3c Mon Sep 17 00:00:00 2001 From: Souvik606 Date: Fri, 5 Jun 2026 20:05:10 +0530 Subject: [PATCH 07/18] Fixed linting errors --- internal/storage/memtable/skiplist_test.go | 2 +- internal/storage/wal/record_test.go | 7 ++++--- internal/storage/wal/recovery.go | 10 +++++++--- internal/storage/wal/writer.go | 4 ++-- internal/storage/wal/writer_test.go | 5 +++-- 5 files changed, 17 insertions(+), 11 deletions(-) diff --git a/internal/storage/memtable/skiplist_test.go b/internal/storage/memtable/skiplist_test.go index bbdb9f5..ade959a 100644 --- a/internal/storage/memtable/skiplist_test.go +++ b/internal/storage/memtable/skiplist_test.go @@ -201,7 +201,7 @@ func TestSkipList_StrictConcurrency(t *testing.T) { defer waitGroup.Done() for i := 0; i < 1000; i++ { val, err := skipList.Get(key) - if err != nil && err != ErrKeyNotFound { + if err != nil && !errors.Is(err, ErrKeyNotFound) { t.Errorf("unexpected error on Get: %v", err) } if err == nil { diff --git a/internal/storage/wal/record_test.go b/internal/storage/wal/record_test.go index 1b924a2..9e1793c 100644 --- a/internal/storage/wal/record_test.go +++ b/internal/storage/wal/record_test.go @@ -3,6 +3,7 @@ package wal import ( "bytes" "encoding/binary" + "errors" "hash/crc32" "testing" ) @@ -182,7 +183,7 @@ func TestUnmarshal_TruncatedFrame_TooShort(t *testing.T) { } for _, data := range cases { _, err := UnmarshalRecord(data) - if err != ErrTruncated { + if !errors.Is(err, ErrTruncated) { t.Errorf("input len=%d: expected ErrTruncated, got %v", len(data), err) } } @@ -215,7 +216,7 @@ func TestUnmarshal_CRCCorruption(t *testing.T) { frame := r.Marshal() tc.corruptFn(frame) _, err := UnmarshalRecord(frame) - if err != ErrInvalidCRC { + if !errors.Is(err, ErrInvalidCRC) { t.Errorf("expected ErrInvalidCRC, got %v", err) } }) @@ -231,7 +232,7 @@ func TestUnmarshal_KeyLengthExceedsFrame(t *testing.T) { binary.LittleEndian.PutUint32(frame[0:4], newCRC) _, err := UnmarshalRecord(frame) - if err != ErrTruncated { + if !errors.Is(err, ErrTruncated) { t.Errorf("expected ErrTruncated when keyLen > frame size, got %v", err) } } diff --git a/internal/storage/wal/recovery.go b/internal/storage/wal/recovery.go index 9e86426..9014c6c 100644 --- a/internal/storage/wal/recovery.go +++ b/internal/storage/wal/recovery.go @@ -42,7 +42,9 @@ func Replay(directory string, engine MemTable) (int, error) { for _, fileName := range walFiles { var segmentID int - fmt.Sscanf(fileName, "%d.wal", &segmentID) + if _, scanErr := fmt.Sscanf(fileName, "%d.wal", &segmentID); scanErr != nil { + return 0, fmt.Errorf("invalid WAL filename %s: %w", fileName, scanErr) + } if segmentID > highestSegmentID { highestSegmentID = segmentID } @@ -64,7 +66,7 @@ func Replay(directory string, engine MemTable) (int, error) { } func replayFile(filePath string, engine MemTable) (err error) { - file, err := os.OpenFile(filePath, os.O_RDWR, 0644) + file, err := os.OpenFile(filePath, os.O_RDWR, 0o644) if err != nil { return fmt.Errorf("unable to open WAL segment for reading: %w", err) } @@ -115,7 +117,9 @@ func replayFile(filePath string, engine MemTable) (err error) { return file.Truncate(validBytes) } - fullFrame := append(headerBuffer, payloadBuffer...) + fullFrame := make([]byte, 8+payloadSizeBytes) + copy(fullFrame[:8], headerBuffer) + copy(fullFrame[8:], payloadBuffer) record, err := UnmarshalRecord(fullFrame) if err != nil { diff --git a/internal/storage/wal/writer.go b/internal/storage/wal/writer.go index a1f6034..4e6b848 100644 --- a/internal/storage/wal/writer.go +++ b/internal/storage/wal/writer.go @@ -28,7 +28,7 @@ type LogWriter struct { } func NewLogWriter(directory string, nextSegmentID int) (*LogWriter, error) { - if err := os.MkdirAll(directory, 0755); err != nil { + if err := os.MkdirAll(directory, 0o755); err != nil { return nil, fmt.Errorf("failed to initialize WAL directory structure at %s: %w", directory, err) } @@ -61,7 +61,7 @@ func (writer *LogWriter) rotateActiveFile() error { } segmentPath := filepath.Join(writer.directory, fmt.Sprintf("%06d.wal", writer.currentSegmentID)) - file, err := os.OpenFile(segmentPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + file, err := os.OpenFile(segmentPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) if err != nil { return fmt.Errorf("failed to open new WAL segment %s: %w", segmentPath, err) } diff --git a/internal/storage/wal/writer_test.go b/internal/storage/wal/writer_test.go index 09a8961..ead44bd 100644 --- a/internal/storage/wal/writer_test.go +++ b/internal/storage/wal/writer_test.go @@ -1,6 +1,7 @@ package wal import ( + "errors" "fmt" "os" "sync" @@ -170,7 +171,7 @@ func TestAppend_EmptyKey_Rejected(t *testing.T) { defer w.Close() r := &Record{Opcode: OpcodePut, Key: []byte{}, Value: []byte("v")} - if err = w.Append(r); err != ErrEmptyKey { + if err = w.Append(r); !errors.Is(err, ErrEmptyKey) { t.Errorf("expected ErrEmptyKey for empty key, got %v", err) } } @@ -184,7 +185,7 @@ func TestAppend_NilKey_Rejected(t *testing.T) { defer w.Close() r := &Record{Opcode: OpcodePut, Key: nil, Value: []byte("v")} - if err = w.Append(r); err != ErrEmptyKey { + if err = w.Append(r); !errors.Is(err, ErrEmptyKey) { t.Errorf("expected ErrEmptyKey for nil key, got %v", err) } } From f687e117ff71f854f8b0ff3dfa69e9db634deb0c Mon Sep 17 00:00:00 2001 From: Souvik606 Date: Fri, 5 Jun 2026 20:15:54 +0530 Subject: [PATCH 08/18] Document the code to explain the functions and record format --- internal/storage/wal/errors.go | 7 +++++++ internal/storage/wal/record.go | 10 ++++++++++ internal/storage/wal/record_test.go | 20 ++++++++++++++++++++ internal/storage/wal/recovery.go | 8 ++++++++ internal/storage/wal/recovery_test.go | 27 +++++++++++++++++++++++++++ internal/storage/wal/writer.go | 15 +++++++++++++++ internal/storage/wal/writer_test.go | 21 +++++++++++++++++++++ 7 files changed, 108 insertions(+) diff --git a/internal/storage/wal/errors.go b/internal/storage/wal/errors.go index dff1f31..9a1250b 100644 --- a/internal/storage/wal/errors.go +++ b/internal/storage/wal/errors.go @@ -1,9 +1,16 @@ +// Package wal implements a Write-Ahead Log (WAL) for penguin-db. +// It supports log appending, rotation, record serialization, and replay recovery. package wal import "errors" var ( + // ErrInvalidCRC is returned when a record's checksum does not match its payload. ErrInvalidCRC = errors.New("corrupt wal record: crc32 mismatch") + + // ErrTruncated is returned when the log record has an unexpected size or EOF is reached mid-record. ErrTruncated = errors.New("corrupt wal record: truncated payload") + + // ErrEmptyKey is returned when attempting to write a log record with a zero-length key. ErrEmptyKey = errors.New("wal record rejected: key must not be empty") ) diff --git a/internal/storage/wal/record.go b/internal/storage/wal/record.go index 06a88d6..3fbd319 100644 --- a/internal/storage/wal/record.go +++ b/internal/storage/wal/record.go @@ -6,16 +6,23 @@ import ( ) const ( + // OpcodePut represents a put/insert operation in the WAL. OpcodePut uint8 = 0 + // OpcodeDelete represents a delete operation in the WAL. OpcodeDelete uint8 = 1 ) +// Record represents a single change logged in the WAL, wrapping an operation +// type (Opcode), key, and value payload. type Record struct { Opcode uint8 Key []byte Value []byte } +// Marshal serializes the Record into a binary frame with a 4-byte CRC checksum, +// a 4-byte frame size header, a 1-byte opcode, a 2-byte key length, and the +// raw key and value bytes. func (record *Record) Marshal() []byte { metadataAndDataSize := 3 + len(record.Key) + len(record.Value) totalFrameSizeBytes := 8 + metadataAndDataSize @@ -35,6 +42,9 @@ func (record *Record) Marshal() []byte { return frameBuffer } +// UnmarshalRecord deserializes a raw binary frame and reconstructs the original +// Record. It validates the data integrity using the CRC checksum and performs +// bounds checking on payload lengths. func UnmarshalRecord(frameData []byte) (*Record, error) { if len(frameData) < 11 { return nil, ErrTruncated diff --git a/internal/storage/wal/record_test.go b/internal/storage/wal/record_test.go index 9e1793c..1f79296 100644 --- a/internal/storage/wal/record_test.go +++ b/internal/storage/wal/record_test.go @@ -8,10 +8,13 @@ import ( "testing" ) +// buildValidFrame is a helper that constructs a valid binary frame from a Record. func buildValidFrame(r *Record) []byte { return r.Marshal() } +// TestMarshal_FrameLayout tests that the layout of the marshaled frame meets +// specification requirements (proper size, opcode location, and key/value placement). func TestMarshal_FrameLayout(t *testing.T) { r := &Record{Opcode: OpcodePut, Key: []byte("hello"), Value: []byte("world")} frame := r.Marshal() @@ -44,6 +47,8 @@ func TestMarshal_FrameLayout(t *testing.T) { } } +// TestMarshal_CRCCoversPayload verifies that the CRC-32 checksum is calculated +// correctly over the rest of the frame payload. func TestMarshal_CRCCoversPayload(t *testing.T) { r := &Record{Opcode: OpcodeDelete, Key: []byte("k"), Value: nil} frame := r.Marshal() @@ -54,6 +59,7 @@ func TestMarshal_CRCCoversPayload(t *testing.T) { } } +// TestMarshal_OpcodeDelete verifies that deletion records marshal with OpcodeDelete. func TestMarshal_OpcodeDelete(t *testing.T) { r := &Record{Opcode: OpcodeDelete, Key: []byte("mykey"), Value: nil} frame := r.Marshal() @@ -62,6 +68,7 @@ func TestMarshal_OpcodeDelete(t *testing.T) { } } +// TestMarshal_ZeroLengthKey tests marshaling records with empty or nil keys. func TestMarshal_ZeroLengthKey(t *testing.T) { cases := []struct { name string @@ -80,6 +87,7 @@ func TestMarshal_ZeroLengthKey(t *testing.T) { } } +// TestMarshal_ZeroLengthValue tests marshaling records with empty or nil values. func TestMarshal_ZeroLengthValue(t *testing.T) { cases := []struct { name string @@ -100,6 +108,7 @@ func TestMarshal_ZeroLengthValue(t *testing.T) { } } +// TestMarshal_LargePayload tests marshalling records with large keys and values. func TestMarshal_LargePayload(t *testing.T) { key := bytes.Repeat([]byte("k"), 1024) value := bytes.Repeat([]byte("v"), 4096) @@ -115,6 +124,8 @@ func TestMarshal_LargePayload(t *testing.T) { } } +// TestMarshal_BinaryKeyAndValue ensures that arbitrary binary data is preserved +// during marshalling and unmarshalling. func TestMarshal_BinaryKeyAndValue(t *testing.T) { key := []byte{0x00, 0xFF, 0x7F, 0x80} value := []byte{0x01, 0x02, 0x03} @@ -133,6 +144,7 @@ func TestMarshal_BinaryKeyAndValue(t *testing.T) { } } +// TestMarshal_Idempotent verifies that Marshal output is identical for successive calls. func TestMarshal_Idempotent(t *testing.T) { r := &Record{Opcode: OpcodePut, Key: []byte("idempotent"), Value: []byte("yes")} if !bytes.Equal(r.Marshal(), r.Marshal()) { @@ -140,6 +152,7 @@ func TestMarshal_Idempotent(t *testing.T) { } } +// TestUnmarshal_RoundTrip_Put checks that a serialized Put record can be accurately reconstructed. func TestUnmarshal_RoundTrip_Put(t *testing.T) { original := &Record{Opcode: OpcodePut, Key: []byte("name"), Value: []byte("penguin")} frame := buildValidFrame(original) @@ -159,6 +172,7 @@ func TestUnmarshal_RoundTrip_Put(t *testing.T) { } } +// TestUnmarshal_RoundTrip_Delete checks that a serialized Delete record can be accurately reconstructed. func TestUnmarshal_RoundTrip_Delete(t *testing.T) { original := &Record{Opcode: OpcodeDelete, Key: []byte("to-remove"), Value: nil} frame := buildValidFrame(original) @@ -175,6 +189,7 @@ func TestUnmarshal_RoundTrip_Delete(t *testing.T) { } } +// TestUnmarshal_TruncatedFrame_TooShort verifies that short inputs return ErrTruncated. func TestUnmarshal_TruncatedFrame_TooShort(t *testing.T) { cases := [][]byte{ {}, @@ -189,6 +204,7 @@ func TestUnmarshal_TruncatedFrame_TooShort(t *testing.T) { } } +// TestUnmarshal_MinimalFrame_EmptyKeyNoValue verifies parsing of minimally sized frames. func TestUnmarshal_MinimalFrame_EmptyKeyNoValue(t *testing.T) { r := &Record{Opcode: OpcodePut, Key: []byte{}, Value: nil} frame := r.Marshal() @@ -201,6 +217,7 @@ func TestUnmarshal_MinimalFrame_EmptyKeyNoValue(t *testing.T) { } } +// TestUnmarshal_CRCCorruption tests that altered headers or payloads result in checksum errors. func TestUnmarshal_CRCCorruption(t *testing.T) { cases := []struct { name string @@ -223,6 +240,7 @@ func TestUnmarshal_CRCCorruption(t *testing.T) { } } +// TestUnmarshal_KeyLengthExceedsFrame checks that key lengths exceeding frame limits are rejected. func TestUnmarshal_KeyLengthExceedsFrame(t *testing.T) { r := &Record{Opcode: OpcodePut, Key: []byte("ab"), Value: []byte("v")} frame := r.Marshal() @@ -237,6 +255,7 @@ func TestUnmarshal_KeyLengthExceedsFrame(t *testing.T) { } } +// TestUnmarshal_NilValueForDeleteRecord checks that deleted records correctly parse nil values. func TestUnmarshal_NilValueForDeleteRecord(t *testing.T) { r := &Record{Opcode: OpcodeDelete, Key: []byte("mykey"), Value: nil} frame := r.Marshal() @@ -250,6 +269,7 @@ func TestUnmarshal_NilValueForDeleteRecord(t *testing.T) { } } +// TestOpcodeValues validates correctness of package opcode constants. func TestOpcodeValues(t *testing.T) { if OpcodePut != 0 { t.Errorf("OpcodePut = %d, want 0", OpcodePut) diff --git a/internal/storage/wal/recovery.go b/internal/storage/wal/recovery.go index 9014c6c..dc7f2f3 100644 --- a/internal/storage/wal/recovery.go +++ b/internal/storage/wal/recovery.go @@ -11,11 +11,16 @@ import ( "sort" ) +// MemTable defines the minimal interface for in-memory storage engine components +// that can consume replayed WAL records during recovery. type MemTable interface { Put(key, value []byte) error Delete(key []byte) error } +// Replay scans the specified WAL directory, identifies all segment files +// matching the *.wal pattern, and replays their logged operations onto the +// target MemTable. It returns the highest segment ID found or 1 if fresh. func Replay(directory string, engine MemTable) (int, error) { slog.Debug("starting WAL recovery sequence", "directory", directory) @@ -65,6 +70,9 @@ func Replay(directory string, engine MemTable) (int, error) { return highestSegmentID, nil } +// replayFile opens a single WAL segment, reads it frame-by-frame, validates +// check-sums and frame sizes, and applies the operations onto the MemTable. +// If it encounters corruption or a partial write, it truncates the segment. func replayFile(filePath string, engine MemTable) (err error) { file, err := os.OpenFile(filePath, os.O_RDWR, 0o644) if err != nil { diff --git a/internal/storage/wal/recovery_test.go b/internal/storage/wal/recovery_test.go index e516cf6..8e18c52 100644 --- a/internal/storage/wal/recovery_test.go +++ b/internal/storage/wal/recovery_test.go @@ -9,6 +9,8 @@ import ( "testing" ) +// mockMemTable implements the MemTable interface for tracking replayed actions +// and simulating write failures during recovery tests. type mockMemTable struct { puts map[string][]byte deletes []string @@ -16,10 +18,12 @@ type mockMemTable struct { delErr error } +// newMockMemTable constructs an empty mockMemTable. func newMockMemTable() *mockMemTable { return &mockMemTable{puts: make(map[string][]byte)} } +// Put records a put operation or returns an injected failure error. func (m *mockMemTable) Put(key, value []byte) error { if m.putErr != nil { return m.putErr @@ -28,6 +32,7 @@ func (m *mockMemTable) Put(key, value []byte) error { return nil } +// Delete records a delete operation or returns an injected failure error. func (m *mockMemTable) Delete(key []byte) error { if m.delErr != nil { return m.delErr @@ -36,6 +41,7 @@ func (m *mockMemTable) Delete(key []byte) error { return nil } +// writeRecordsToFile serialized and appends a list of Records to a segment file. func writeRecordsToFile(t *testing.T, path string, records []*Record) { t.Helper() f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) @@ -50,10 +56,13 @@ func writeRecordsToFile(t *testing.T, path string, records []*Record) { } } +// segmentPath constructs the absolute filepath for a given segment ID. func segmentPath(dir string, id int) string { return filepath.Join(dir, fmt.Sprintf("%06d.wal", id)) } +// TestReplay_NonExistentDirectory_ReturnsFreshSegmentID verifies recovery behaves +// correctly if the WAL directory does not exist. func TestReplay_NonExistentDirectory_ReturnsFreshSegmentID(t *testing.T) { dir := filepath.Join(t.TempDir(), "no-such-wal") mem := newMockMemTable() @@ -70,6 +79,7 @@ func TestReplay_NonExistentDirectory_ReturnsFreshSegmentID(t *testing.T) { } } +// TestReplay_EmptyDirectory_ReturnsFreshSegmentID checks recovery on an empty directory. func TestReplay_EmptyDirectory_ReturnsFreshSegmentID(t *testing.T) { dir := t.TempDir() mem := newMockMemTable() @@ -83,6 +93,7 @@ func TestReplay_EmptyDirectory_ReturnsFreshSegmentID(t *testing.T) { } } +// TestReplay_NonWALFilesIgnored checks that recovery ignores files without .wal extensions. func TestReplay_NonWALFilesIgnored(t *testing.T) { dir := t.TempDir() for _, name := range []string{"data.sst", "manifest", "000001.log", "LOCK"} { @@ -100,6 +111,7 @@ func TestReplay_NonWALFilesIgnored(t *testing.T) { } } +// TestReplay_SingleSegment_AllPuts checks that a simple segment with Put operations is fully replayed. func TestReplay_SingleSegment_AllPuts(t *testing.T) { dir := t.TempDir() records := []*Record{ @@ -124,6 +136,7 @@ func TestReplay_SingleSegment_AllPuts(t *testing.T) { } } +// TestReplay_SingleSegment_Deletes checks that Delete operations are replayed. func TestReplay_SingleSegment_Deletes(t *testing.T) { dir := t.TempDir() writeRecordsToFile(t, segmentPath(dir, 1), []*Record{ @@ -140,6 +153,7 @@ func TestReplay_SingleSegment_Deletes(t *testing.T) { } } +// TestReplay_MultipleSegments_ReplayedInOrder verifies segment logs are recovered in sequential order. func TestReplay_MultipleSegments_ReplayedInOrder(t *testing.T) { dir := t.TempDir() @@ -166,6 +180,7 @@ func TestReplay_MultipleSegments_ReplayedInOrder(t *testing.T) { } } +// TestReplay_ReturnsHighestSegmentID verifies recovery identifies and returns the highest active segment ID. func TestReplay_ReturnsHighestSegmentID(t *testing.T) { dir := t.TempDir() for _, id := range []int{1, 5, 10} { @@ -183,6 +198,7 @@ func TestReplay_ReturnsHighestSegmentID(t *testing.T) { } } +// TestReplay_UnknownOpcode_Ignored verifies that records with invalid opcodes are ignored during replay. func TestReplay_UnknownOpcode_Ignored(t *testing.T) { dir := t.TempDir() writeRecordsToFile(t, segmentPath(dir, 1), []*Record{ @@ -198,6 +214,7 @@ func TestReplay_UnknownOpcode_Ignored(t *testing.T) { } } +// TestReplay_CorruptedCRC_TruncatesFile checks that recovery detects CRC mismatches and truncates. func TestReplay_CorruptedCRC_TruncatesFile(t *testing.T) { dir := t.TempDir() path := segmentPath(dir, 1) @@ -233,6 +250,7 @@ func TestReplay_CorruptedCRC_TruncatesFile(t *testing.T) { } } +// TestReplay_TruncatedHeader_TruncatesFile checks truncation when the frame header is truncated. func TestReplay_TruncatedHeader_TruncatesFile(t *testing.T) { dir := t.TempDir() path := segmentPath(dir, 1) @@ -254,6 +272,7 @@ func TestReplay_TruncatedHeader_TruncatesFile(t *testing.T) { } } +// TestReplay_TruncatedPayload_TruncatesFile checks truncation when the frame payload is truncated. func TestReplay_TruncatedPayload_TruncatesFile(t *testing.T) { dir := t.TempDir() path := segmentPath(dir, 1) @@ -282,6 +301,7 @@ func TestReplay_TruncatedPayload_TruncatesFile(t *testing.T) { } } +// TestReplay_EmptySegmentFile_NoError checks that an empty WAL segment is handled gracefully. func TestReplay_EmptySegmentFile_NoError(t *testing.T) { dir := t.TempDir() os.WriteFile(segmentPath(dir, 1), []byte{}, 0644) @@ -296,6 +316,7 @@ func TestReplay_EmptySegmentFile_NoError(t *testing.T) { } } +// TestReplay_MemTablePutError_PropagatesError checks failure propagation on MemTable Put errors. func TestReplay_MemTablePutError_PropagatesError(t *testing.T) { dir := t.TempDir() writeRecordsToFile(t, segmentPath(dir, 1), []*Record{ @@ -310,6 +331,7 @@ func TestReplay_MemTablePutError_PropagatesError(t *testing.T) { } } +// TestReplay_MemTableDeleteError_PropagatesError checks failure propagation on MemTable Delete errors. func TestReplay_MemTableDeleteError_PropagatesError(t *testing.T) { dir := t.TempDir() writeRecordsToFile(t, segmentPath(dir, 1), []*Record{ @@ -324,6 +346,7 @@ func TestReplay_MemTableDeleteError_PropagatesError(t *testing.T) { } } +// TestReplay_MixedPutsAndDeletes_CorrectOrder verifies interleaved Puts/Deletes replay in correct order. func TestReplay_MixedPutsAndDeletes_CorrectOrder(t *testing.T) { dir := t.TempDir() mem := newMockMemTable() @@ -357,6 +380,7 @@ func TestReplay_MixedPutsAndDeletes_CorrectOrder(t *testing.T) { } } +// TestReplay_SegmentsSortedNumerically checks sorting logic on WAL segments with double digits. func TestReplay_SegmentsSortedNumerically(t *testing.T) { dir := t.TempDir() @@ -379,6 +403,7 @@ func TestReplay_SegmentsSortedNumerically(t *testing.T) { } } +// TestReplay_SubdirectoriesAreIgnored verifies subdirectories inside the WAL directory are skipped. func TestReplay_SubdirectoriesAreIgnored(t *testing.T) { dir := t.TempDir() subDir := filepath.Join(dir, "archive.wal") @@ -399,6 +424,7 @@ func TestReplay_SubdirectoriesAreIgnored(t *testing.T) { } } +// TestReplay_InvalidFrameSize_TooSmall_TruncatesFile verifies truncation on underflowing frame sizes. func TestReplay_InvalidFrameSize_TooSmall_TruncatesFile(t *testing.T) { dir := t.TempDir() path := segmentPath(dir, 1) @@ -432,6 +458,7 @@ func TestReplay_InvalidFrameSize_TooSmall_TruncatesFile(t *testing.T) { } } +// TestReplay_InvalidFrameSize_TooLarge_TruncatesFile verifies truncation on excessively large frame sizes. func TestReplay_InvalidFrameSize_TooLarge_TruncatesFile(t *testing.T) { dir := t.TempDir() path := segmentPath(dir, 1) diff --git a/internal/storage/wal/writer.go b/internal/storage/wal/writer.go index 4e6b848..9b7992f 100644 --- a/internal/storage/wal/writer.go +++ b/internal/storage/wal/writer.go @@ -10,11 +10,16 @@ import ( const MaxSegmentSizeBytes int64 = 32 * 1024 * 1024 +// commitTicket represents an ingestion task containing serialized record data +// and a channel to communicate the result of the log write and fsync. type commitTicket struct { frameData []byte resultChan chan error } +// LogWriter manages sequential appending to WAL segment files. It coordinates +// concurrent appends, runs a batching background worker to flush writes, and +// handles file rotation when a segment exceeds its capacity. type LogWriter struct { directory string activeFile *os.File @@ -27,6 +32,8 @@ type LogWriter struct { closeOnce sync.Once } +// NewLogWriter creates a new LogWriter instance, initializing the WAL directory, +// creating the initial active segment, and launching the background batch worker. func NewLogWriter(directory string, nextSegmentID int) (*LogWriter, error) { if err := os.MkdirAll(directory, 0o755); err != nil { return nil, fmt.Errorf("failed to initialize WAL directory structure at %s: %w", directory, err) @@ -49,6 +56,8 @@ func NewLogWriter(directory string, nextSegmentID int) (*LogWriter, error) { return writer, nil } +// rotateActiveFile closes the current active segment file after fsyncing its data, +// increments the segment ID, and opens a new segment file for write access. func (writer *LogWriter) rotateActiveFile() error { if writer.activeFile != nil { if err := writer.activeFile.Sync(); err != nil { @@ -71,6 +80,8 @@ func (writer *LogWriter) rotateActiveFile() error { return nil } +// Append writes a single Record into the Write-Ahead Log. It blocks until the +// record is durably persisted (written and synced) or the log is closed. func (writer *LogWriter) Append(record *Record) error { if len(record.Key) == 0 { return ErrEmptyKey @@ -98,6 +109,8 @@ func (writer *LogWriter) Append(record *Record) error { } } +// batchWorker runs in a background goroutine, receiving commit tickets from the +// ingestion channel, batching them, writing to the active file, and syncing them. func (writer *LogWriter) batchWorker() { defer writer.workerWaitGroup.Done() @@ -154,6 +167,8 @@ func (writer *LogWriter) batchWorker() { } } +// Close closes the active WAL segment file, terminates the background worker, +// and ensures all pending writes have been durably synced to disk. func (writer *LogWriter) Close() error { var closeErr error writer.closeOnce.Do(func() { diff --git a/internal/storage/wal/writer_test.go b/internal/storage/wal/writer_test.go index ead44bd..4413c36 100644 --- a/internal/storage/wal/writer_test.go +++ b/internal/storage/wal/writer_test.go @@ -10,6 +10,7 @@ import ( "time" ) +// TestNewLogWriter_CreatesDirectory checks that initializing LogWriter creates the target directory. func TestNewLogWriter_CreatesDirectory(t *testing.T) { base := t.TempDir() dir := fmt.Sprintf("%s/nested/wal", base) @@ -25,6 +26,7 @@ func TestNewLogWriter_CreatesDirectory(t *testing.T) { } } +// TestNewLogWriter_CreatesFirstSegmentFile checks that the first log segment is created upon initialization. func TestNewLogWriter_CreatesFirstSegmentFile(t *testing.T) { dir := t.TempDir() w, err := NewLogWriter(dir, 1) @@ -38,6 +40,7 @@ func TestNewLogWriter_CreatesFirstSegmentFile(t *testing.T) { } } +// TestNewLogWriter_RespectsStartSegmentID verifies that the writer uses the provided initial segment ID. func TestNewLogWriter_RespectsStartSegmentID(t *testing.T) { dir := t.TempDir() w, err := NewLogWriter(dir, 7) @@ -51,6 +54,7 @@ func TestNewLogWriter_RespectsStartSegmentID(t *testing.T) { } } +// TestNewLogWriter_InvalidDirectory_ReturnsError checks that invalid directories return initialization errors. func TestNewLogWriter_InvalidDirectory_ReturnsError(t *testing.T) { tmpFile, err := os.CreateTemp(t.TempDir(), "not-a-dir") if err != nil { @@ -64,6 +68,7 @@ func TestNewLogWriter_InvalidDirectory_ReturnsError(t *testing.T) { } } +// TestAppend_SingleRecord_WrittenToDisk checks that a single record write changes the file size on disk. func TestAppend_SingleRecord_WrittenToDisk(t *testing.T) { dir := t.TempDir() w, err := NewLogWriter(dir, 1) @@ -86,6 +91,7 @@ func TestAppend_SingleRecord_WrittenToDisk(t *testing.T) { } } +// TestAppend_MultipleRecords_AllWritten checks that multiple sequential records are successfully appended. func TestAppend_MultipleRecords_AllWritten(t *testing.T) { dir := t.TempDir() w, err := NewLogWriter(dir, 1) @@ -126,6 +132,7 @@ func TestAppend_MultipleRecords_AllWritten(t *testing.T) { } } +// TestAppend_RecordRoundtrip_ViaReplay verifies that written records are completely recoverable. func TestAppend_RecordRoundtrip_ViaReplay(t *testing.T) { dir := t.TempDir() w, err := NewLogWriter(dir, 1) @@ -148,6 +155,7 @@ func TestAppend_RecordRoundtrip_ViaReplay(t *testing.T) { } } +// TestAppend_AfterClose_ReturnsError checks that writes are rejected after Close is called. func TestAppend_AfterClose_ReturnsError(t *testing.T) { dir := t.TempDir() w, err := NewLogWriter(dir, 1) @@ -162,6 +170,7 @@ func TestAppend_AfterClose_ReturnsError(t *testing.T) { } } +// TestAppend_EmptyKey_Rejected verifies empty key appends are rejected with ErrEmptyKey. func TestAppend_EmptyKey_Rejected(t *testing.T) { dir := t.TempDir() w, err := NewLogWriter(dir, 1) @@ -176,6 +185,7 @@ func TestAppend_EmptyKey_Rejected(t *testing.T) { } } +// TestAppend_NilKey_Rejected verifies nil key appends are rejected with ErrEmptyKey. func TestAppend_NilKey_Rejected(t *testing.T) { dir := t.TempDir() w, err := NewLogWriter(dir, 1) @@ -190,6 +200,7 @@ func TestAppend_NilKey_Rejected(t *testing.T) { } } +// TestAppend_EmptyValue_Allowed checks that empty/nil values are allowed in WAL entries. func TestAppend_EmptyValue_Allowed(t *testing.T) { dir := t.TempDir() w, err := NewLogWriter(dir, 1) @@ -204,6 +215,7 @@ func TestAppend_EmptyValue_Allowed(t *testing.T) { } } +// TestRotation_NewSegmentCreatedAfterSizeExceeded checks that segment file rotates on limit overrun. func TestRotation_NewSegmentCreatedAfterSizeExceeded(t *testing.T) { dir := t.TempDir() w, err := NewLogWriter(dir, 1) @@ -224,6 +236,7 @@ func TestRotation_NewSegmentCreatedAfterSizeExceeded(t *testing.T) { } } +// TestRotation_OldSegmentSyncedOnRotation verifies old segment files are synced upon rotation. func TestRotation_OldSegmentSyncedOnRotation(t *testing.T) { dir := t.TempDir() w, err := NewLogWriter(dir, 1) @@ -257,6 +270,7 @@ func TestRotation_OldSegmentSyncedOnRotation(t *testing.T) { } } +// TestRotation_SizeResetAfterRotation verifies that segment size tracking resets after file rotation. func TestRotation_SizeResetAfterRotation(t *testing.T) { dir := t.TempDir() w, err := NewLogWriter(dir, 1) @@ -277,6 +291,7 @@ func TestRotation_SizeResetAfterRotation(t *testing.T) { } } +// TestAppend_ConcurrentWrites_NoDataRace validates concurrent WAL writes do not cause data races. func TestAppend_ConcurrentWrites_NoDataRace(t *testing.T) { dir := t.TempDir() w, err := NewLogWriter(dir, 1) @@ -315,6 +330,7 @@ func TestAppend_ConcurrentWrites_NoDataRace(t *testing.T) { } } +// TestAppend_ConcurrentWrites_AllRecordsRecoverable verifies all concurrent writes are cleanly replayed. func TestAppend_ConcurrentWrites_AllRecordsRecoverable(t *testing.T) { dir := t.TempDir() w, err := NewLogWriter(dir, 1) @@ -355,6 +371,7 @@ func TestAppend_ConcurrentWrites_AllRecordsRecoverable(t *testing.T) { } } +// TestClose_IsIdempotent validates that closing multiple times behaves correctly. func TestClose_IsIdempotent(t *testing.T) { dir := t.TempDir() w, err := NewLogWriter(dir, 1) @@ -370,6 +387,7 @@ func TestClose_IsIdempotent(t *testing.T) { } } +// TestClose_BlocksUntilWorkerDone checks that Close blocks until background activities terminate. func TestClose_BlocksUntilWorkerDone(t *testing.T) { dir := t.TempDir() w, err := NewLogWriter(dir, 1) @@ -390,6 +408,7 @@ func TestClose_BlocksUntilWorkerDone(t *testing.T) { } } +// TestClose_SyncsDataToDisk verifies that closing the log writer flushes in-flight data. func TestClose_SyncsDataToDisk(t *testing.T) { dir := t.TempDir() w, err := NewLogWriter(dir, 1) @@ -414,6 +433,7 @@ func TestClose_SyncsDataToDisk(t *testing.T) { } } +// TestMaxSegmentSizeBytes_Is32MB checks the segment size boundary constant. func TestMaxSegmentSizeBytes_Is32MB(t *testing.T) { expected := int64(32 * 1024 * 1024) if MaxSegmentSizeBytes != expected { @@ -421,6 +441,7 @@ func TestMaxSegmentSizeBytes_Is32MB(t *testing.T) { } } +// TestBatchWorker_GroupCommit_AllTicketsSignalled verifies that all concurrent ticket requests receive replies. func TestBatchWorker_GroupCommit_AllTicketsSignalled(t *testing.T) { dir := t.TempDir() w, err := NewLogWriter(dir, 1) From 4e34f853a6ba71bc42cd8cf97408c99448361fb5 Mon Sep 17 00:00:00 2001 From: Souvik606 Date: Fri, 5 Jun 2026 21:39:05 +0530 Subject: [PATCH 09/18] Fixed space formatting linting errors --- internal/storage/wal/errors.go | 4 ++-- internal/storage/wal/record.go | 2 +- internal/storage/wal/recovery.go | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/internal/storage/wal/errors.go b/internal/storage/wal/errors.go index 9a1250b..b19c51b 100644 --- a/internal/storage/wal/errors.go +++ b/internal/storage/wal/errors.go @@ -9,8 +9,8 @@ var ( ErrInvalidCRC = errors.New("corrupt wal record: crc32 mismatch") // ErrTruncated is returned when the log record has an unexpected size or EOF is reached mid-record. - ErrTruncated = errors.New("corrupt wal record: truncated payload") + ErrTruncated = errors.New("corrupt wal record: truncated payload") // ErrEmptyKey is returned when attempting to write a log record with a zero-length key. - ErrEmptyKey = errors.New("wal record rejected: key must not be empty") + ErrEmptyKey = errors.New("wal record rejected: key must not be empty") ) diff --git a/internal/storage/wal/record.go b/internal/storage/wal/record.go index 3fbd319..65d8903 100644 --- a/internal/storage/wal/record.go +++ b/internal/storage/wal/record.go @@ -7,7 +7,7 @@ import ( const ( // OpcodePut represents a put/insert operation in the WAL. - OpcodePut uint8 = 0 + OpcodePut uint8 = 0 // OpcodeDelete represents a delete operation in the WAL. OpcodeDelete uint8 = 1 ) diff --git a/internal/storage/wal/recovery.go b/internal/storage/wal/recovery.go index dc7f2f3..81eab2c 100644 --- a/internal/storage/wal/recovery.go +++ b/internal/storage/wal/recovery.go @@ -84,7 +84,7 @@ func replayFile(filePath string, engine MemTable) (err error) { } }() - var validBytes int64 = 0 + var validBytes int64 var recordsRecovered int headerBuffer := make([]byte, 8) From 96c242b5941465138d9bd77d9fdd1c2a99fed072 Mon Sep 17 00:00:00 2001 From: Souvik606 Date: Fri, 5 Jun 2026 22:28:15 +0530 Subject: [PATCH 10/18] Add LF line ending and added few tests and bug fixes --- internal/storage/wal/recovery.go | 5 +-- internal/storage/wal/recovery_test.go | 45 +++++++++++++++++++++++++-- 2 files changed, 45 insertions(+), 5 deletions(-) diff --git a/internal/storage/wal/recovery.go b/internal/storage/wal/recovery.go index 81eab2c..b0a555c 100644 --- a/internal/storage/wal/recovery.go +++ b/internal/storage/wal/recovery.go @@ -47,8 +47,9 @@ func Replay(directory string, engine MemTable) (int, error) { for _, fileName := range walFiles { var segmentID int - if _, scanErr := fmt.Sscanf(fileName, "%d.wal", &segmentID); scanErr != nil { - return 0, fmt.Errorf("invalid WAL filename %s: %w", fileName, scanErr) + if n, _ := fmt.Sscanf(fileName, "%d.wal", &segmentID); n != 1 { + slog.Debug("skipping WAL file with unexpected name format", "file", fileName) + continue } if segmentID > highestSegmentID { highestSegmentID = segmentID diff --git a/internal/storage/wal/recovery_test.go b/internal/storage/wal/recovery_test.go index 8e18c52..acf5fe8 100644 --- a/internal/storage/wal/recovery_test.go +++ b/internal/storage/wal/recovery_test.go @@ -41,7 +41,7 @@ func (m *mockMemTable) Delete(key []byte) error { return nil } -// writeRecordsToFile serialized and appends a list of Records to a segment file. +// writeRecordsToFile serializes and appends a list of Records to a segment file. func writeRecordsToFile(t *testing.T, path string, records []*Record) { t.Helper() f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) @@ -225,7 +225,10 @@ func TestReplay_CorruptedCRC_TruncatesFile(t *testing.T) { badFrame := (&Record{Opcode: OpcodePut, Key: []byte("bad"), Value: []byte("x")}).Marshal() badFrame[12] ^= 0xFF - f, _ := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, 0644) + f, err := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, 0644) + if err != nil { + t.Fatal(err) + } f.Write(badFrame) f.Close() @@ -437,7 +440,10 @@ func TestReplay_InvalidFrameSize_TooSmall_TruncatesFile(t *testing.T) { binary.LittleEndian.PutUint32(hdr[4:8], 7) binary.LittleEndian.PutUint32(hdr[0:4], crc32.ChecksumIEEE(hdr[4:])) - f, _ := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, 0644) + f, err := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, 0644) + if err != nil { + t.Fatal(err) + } f.Write(hdr) f.Close() @@ -491,3 +497,36 @@ func TestReplay_InvalidFrameSize_TooLarge_TruncatesFile(t *testing.T) { t.Errorf("file size after truncate = %d, want %d", info.Size(), len(validBytes)) } } + +// TestReplay_MalformedWALFilename_Skipped verifies that files ending in .wal but not starting with a number are skipped during replay. +func TestReplay_MalformedWALFilename_Skipped(t *testing.T) { + dir := t.TempDir() + + goodPath := segmentPath(dir, 1) + good := &Record{Opcode: OpcodePut, Key: []byte("good"), Value: []byte("val")} + writeRecordsToFile(t, goodPath, []*Record{good}) + + badPath := filepath.Join(dir, "malformed.wal") + bad := &Record{Opcode: OpcodePut, Key: []byte("bad"), Value: []byte("val")} + f, err := os.Create(badPath) + if err != nil { + t.Fatal(err) + } + f.Write(bad.Marshal()) + f.Close() + + mem := newMockMemTable() + nextID, err := Replay(dir, mem) + if err != nil { + t.Fatalf("unexpected error replaying: %v", err) + } + if nextID != 1 { + t.Errorf("nextID = %d, want 1", nextID) + } + if _, ok := mem.puts["good"]; !ok { + t.Error("good record was not applied") + } + if _, ok := mem.puts["bad"]; ok { + t.Error("bad record in malformed file was replayed, but should have been skipped") + } +} From 5283cbdc64de786bee9b85d34afe07e63ed7e1b7 Mon Sep 17 00:00:00 2001 From: Saptak Manna Date: Sat, 6 Jun 2026 22:48:47 +0530 Subject: [PATCH 11/18] Remove magic numbers --- internal/storage/wal/record.go | 73 ++++++++++++++++++++++++---------- 1 file changed, 52 insertions(+), 21 deletions(-) diff --git a/internal/storage/wal/record.go b/internal/storage/wal/record.go index 65d8903..c49d2e1 100644 --- a/internal/storage/wal/record.go +++ b/internal/storage/wal/record.go @@ -20,24 +20,53 @@ type Record struct { Value []byte } -// Marshal serializes the Record into a binary frame with a 4-byte CRC checksum, -// a 4-byte frame size header, a 1-byte opcode, a 2-byte key length, and the -// raw key and value bytes. +// Define sizes for each field in the frame +const ( + checksumSize = 4 + frameSizeSize = 4 + opcodeSize = 1 + keyLengthSize = 2 + + // Fixed header size is the sum of all fixed-length fields + fixedHeaderSize = checksumSize + frameSizeSize + opcodeSize + keyLengthSize +) + +// Define offsets for each field to eliminate magic slice indices +const ( + checksumOffset = 0 + frameSizeOffset = checksumOffset + checksumSize + opcodeOffset = frameSizeOffset + frameSizeSize + keyLengthOffset = opcodeOffset + opcodeSize + keyOffset = keyLengthOffset + keyLengthSize +) + +// Marshal serializes the Record into a binary frame. +// +// Frame Layout: +// +-------------+-------------+----------+------------+-----------+-----------+ +// | Checksum | Frame Size | Opcode | Key Length | Key | Value | +// | (4 bytes) | (4 bytes) | (1 byte) | (2 bytes) | (n bytes) | (m bytes) | +// +-------------+-------------+----------+------------+-----------+-----------+ +// +// Note: The Checksum (CRC32) covers all bytes starting from the Frame Size. func (record *Record) Marshal() []byte { - metadataAndDataSize := 3 + len(record.Key) + len(record.Value) - totalFrameSizeBytes := 8 + metadataAndDataSize + keyLen := len(record.Key) + valLen := len(record.Value) + + totalFrameSizeBytes := fixedHeaderSize + keyLen + valLen frameBuffer := make([]byte, totalFrameSizeBytes) - binary.LittleEndian.PutUint32(frameBuffer[4:8], uint32(totalFrameSizeBytes)) - frameBuffer[8] = record.Opcode - binary.LittleEndian.PutUint16(frameBuffer[9:11], uint16(len(record.Key))) + binary.LittleEndian.PutUint32(frameBuffer[frameSizeOffset:opcodeOffset], uint32(totalFrameSizeBytes)) + frameBuffer[opcodeOffset] = record.Opcode + binary.LittleEndian.PutUint16(frameBuffer[keyLengthOffset:keyOffset], uint16(keyLen)) - copy(frameBuffer[11:], record.Key) - copy(frameBuffer[11+len(record.Key):], record.Value) + copy(frameBuffer[keyOffset:], record.Key) + valueOffset := keyOffset + keyLen + copy(frameBuffer[valueOffset:], record.Value) - calculatedChecksum := crc32.ChecksumIEEE(frameBuffer[4:]) - binary.LittleEndian.PutUint32(frameBuffer[0:4], calculatedChecksum) + calculatedChecksum := crc32.ChecksumIEEE(frameBuffer[frameSizeOffset:]) + binary.LittleEndian.PutUint32(frameBuffer[checksumOffset:frameSizeOffset], calculatedChecksum) return frameBuffer } @@ -46,32 +75,34 @@ func (record *Record) Marshal() []byte { // Record. It validates the data integrity using the CRC checksum and performs // bounds checking on payload lengths. func UnmarshalRecord(frameData []byte) (*Record, error) { - if len(frameData) < 11 { + if len(frameData) < fixedHeaderSize { return nil, ErrTruncated } - storedChecksum := binary.LittleEndian.Uint32(frameData[0:4]) - calculatedChecksum := crc32.ChecksumIEEE(frameData[4:]) + storedChecksum := binary.LittleEndian.Uint32(frameData[checksumOffset:frameSizeOffset]) + calculatedChecksum := crc32.ChecksumIEEE(frameData[frameSizeOffset:]) if storedChecksum != calculatedChecksum { return nil, ErrInvalidCRC } - extractedOpcode := frameData[8] - extractedKeyLength := binary.LittleEndian.Uint16(frameData[9:11]) + extractedOpcode := frameData[opcodeOffset] + extractedKeyLength := binary.LittleEndian.Uint16(frameData[keyLengthOffset:keyOffset]) - if len(frameData) < 11+int(extractedKeyLength) { + if len(frameData) < fixedHeaderSize+int(extractedKeyLength) { return nil, ErrTruncated } extractedKey := make([]byte, extractedKeyLength) - copy(extractedKey, frameData[11:11+int(extractedKeyLength)]) + copy(extractedKey, frameData[keyOffset:keyOffset+int(extractedKeyLength)]) - extractedValueLength := len(frameData) - (11 + int(extractedKeyLength)) + extractedValueLength := len(frameData) - (fixedHeaderSize + int(extractedKeyLength)) var extractedValue []byte + if extractedValueLength > 0 { extractedValue = make([]byte, extractedValueLength) - copy(extractedValue, frameData[11+extractedKeyLength:]) + valueOffset := keyOffset + int(extractedKeyLength) + copy(extractedValue, frameData[valueOffset:]) } return &Record{ From 26218d7dd337714245b609a88acefe87c2835979 Mon Sep 17 00:00:00 2001 From: Souvik606 Date: Mon, 8 Jun 2026 13:05:14 +0530 Subject: [PATCH 12/18] Fixed issues of race condition between different channels and file sync issues --- internal/storage/wal/recovery.go | 24 +++++- internal/storage/wal/writer.go | 138 +++++++++++++++++-------------- 2 files changed, 97 insertions(+), 65 deletions(-) diff --git a/internal/storage/wal/recovery.go b/internal/storage/wal/recovery.go index b0a555c..82f13e0 100644 --- a/internal/storage/wal/recovery.go +++ b/internal/storage/wal/recovery.go @@ -102,7 +102,11 @@ func replayFile(filePath string, engine MemTable) (err error) { slog.Debug("unexpected EOF in header, truncating segment", "file", filepath.Base(filePath), "valid_bytes", validBytes) - return file.Truncate(validBytes) + + if truncErr := file.Truncate(validBytes); truncErr != nil { + return truncErr + } + return file.Sync() } return fmt.Errorf("unexpected disk error reading frame header: %w", err) } @@ -113,7 +117,11 @@ func replayFile(filePath string, engine MemTable) (err error) { "file", filepath.Base(filePath), "frame_size", totalFrameSizeBytes, "valid_bytes", validBytes) - return file.Truncate(validBytes) + + if truncErr := file.Truncate(validBytes); truncErr != nil { + return truncErr + } + return file.Sync() } payloadSizeBytes := totalFrameSizeBytes - 8 @@ -123,7 +131,11 @@ func replayFile(filePath string, engine MemTable) (err error) { slog.Debug("unexpected EOF in payload, truncating segment", "file", filepath.Base(filePath), "valid_bytes", validBytes) - return file.Truncate(validBytes) + + if truncErr := file.Truncate(validBytes); truncErr != nil { + return truncErr + } + return file.Sync() } fullFrame := make([]byte, 8+payloadSizeBytes) @@ -137,7 +149,11 @@ func replayFile(filePath string, engine MemTable) (err error) { "file", filepath.Base(filePath), "valid_bytes", validBytes, "error", err) - return file.Truncate(validBytes) + + if truncErr := file.Truncate(validBytes); truncErr != nil { + return truncErr + } + return file.Sync() } return fmt.Errorf("failed to decode valid frame payload: %w", err) } diff --git a/internal/storage/wal/writer.go b/internal/storage/wal/writer.go index 9b7992f..e1b6991 100644 --- a/internal/storage/wal/writer.go +++ b/internal/storage/wal/writer.go @@ -9,6 +9,7 @@ import ( ) const MaxSegmentSizeBytes int64 = 32 * 1024 * 1024 +const MaxBatchSizeBytes int64 = 4 * 1024 * 1024 // commitTicket represents an ingestion task containing serialized record data // and a channel to communicate the result of the log write and fsync. @@ -27,9 +28,11 @@ type LogWriter struct { currentSizeBytes int64 ingestionChannel chan *commitTicket - shutdownSignal chan struct{} - workerWaitGroup sync.WaitGroup - closeOnce sync.Once + stateMu sync.RWMutex + isClosed bool + + workerWaitGroup sync.WaitGroup + closeOnce sync.Once } // NewLogWriter creates a new LogWriter instance, initializing the WAL directory, @@ -43,7 +46,6 @@ func NewLogWriter(directory string, nextSegmentID int) (*LogWriter, error) { directory: directory, currentSegmentID: nextSegmentID, ingestionChannel: make(chan *commitTicket, 10000), - shutdownSignal: make(chan struct{}), } if err := writer.rotateActiveFile(); err != nil { @@ -94,19 +96,17 @@ func (writer *LogWriter) Append(record *Record) error { resultChan: make(chan error, 1), } - slog.Debug("network thread: dropping record into ingestion channel", "frame_size", len(frame)) - - select { - case writer.ingestionChannel <- ticket: - select { - case err := <-ticket.resultChan: - return err - case <-writer.shutdownSignal: - return fmt.Errorf("database engine is currently shutting down, write rejected") - } - case <-writer.shutdownSignal: + writer.stateMu.RLock() + if writer.isClosed { + writer.stateMu.RUnlock() return fmt.Errorf("database engine is currently shutting down, write rejected") } + + slog.Debug("network thread: dropping record into ingestion channel", "frame_size", len(frame)) + writer.ingestionChannel <- ticket + writer.stateMu.RUnlock() + + return <-ticket.resultChan } // batchWorker runs in a background goroutine, receiving commit tickets from the @@ -117,53 +117,14 @@ func (writer *LogWriter) batchWorker() { var commitBatch []*commitTicket var writeBuffer []byte - for { - select { - case <-writer.shutdownSignal: - return - - case firstTicket := <-writer.ingestionChannel: - commitBatch = append(commitBatch[:0], firstTicket) - writeBuffer = append(writeBuffer[:0], firstTicket.frameData...) - - pendingWrites := len(writer.ingestionChannel) - for i := 0; i < pendingWrites; i++ { - ticket := <-writer.ingestionChannel - commitBatch = append(commitBatch, ticket) - writeBuffer = append(writeBuffer, ticket.frameData...) - } - - if writer.currentSizeBytes+int64(len(writeBuffer)) > MaxSegmentSizeBytes { - if err := writer.rotateActiveFile(); err != nil { - for _, ticket := range commitBatch { - ticket.resultChan <- err - } - continue - } - } - - slog.Debug("batch worker: executing group commit", - "batch_size", len(commitBatch), - "total_bytes", len(writeBuffer)) + for ticket := range writer.ingestionChannel { + commitBatch, writeBuffer = writer.gatherBatch(ticket, commitBatch, writeBuffer) - _, err := writer.activeFile.Write(writeBuffer) - if err == nil { - err = writer.activeFile.Sync() - } - if err == nil { - writer.currentSizeBytes += int64(len(writeBuffer)) - } + slog.Debug("batch worker: executing group commit", + "batch_size", len(commitBatch), + "total_bytes", len(writeBuffer)) - if err != nil { - slog.Debug("batch worker: fsync failed", "error", err) - } else { - slog.Debug("batch worker: fsync successful") - } - - for _, ticket := range commitBatch { - ticket.resultChan <- err - } - } + writer.writeAndSyncBatch(commitBatch, writeBuffer) } } @@ -172,8 +133,13 @@ func (writer *LogWriter) batchWorker() { func (writer *LogWriter) Close() error { var closeErr error writer.closeOnce.Do(func() { - close(writer.shutdownSignal) + writer.stateMu.Lock() + writer.isClosed = true + close(writer.ingestionChannel) + writer.stateMu.Unlock() + writer.workerWaitGroup.Wait() + if writer.activeFile != nil { if syncErr := writer.activeFile.Sync(); syncErr != nil { closeErr = syncErr @@ -185,3 +151,53 @@ func (writer *LogWriter) Close() error { }) return closeErr } + +// writeAndSyncBatch handles the disk I/O, segment rotation, and network thread +// notification for a gathered batch of records. +func (writer *LogWriter) writeAndSyncBatch(batch []*commitTicket, buffer []byte) { + if writer.currentSizeBytes+int64(len(buffer)) > MaxSegmentSizeBytes { + if err := writer.rotateActiveFile(); err != nil { + for _, ticket := range batch { + ticket.resultChan <- err + } + return + } + } + + _, err := writer.activeFile.Write(buffer) + if err == nil { + err = writer.activeFile.Sync() + } + + if err == nil { + writer.currentSizeBytes += int64(len(buffer)) + } else { + slog.Debug("batch worker: fsync/write failed", "error", err) + } + + for _, ticket := range batch { + ticket.resultChan <- err + } +} + +// gatherBatch pulls tickets from the ingestion channel up to the MaxBatchSizeBytes limit. +// It takes the first ticket yielded by the range loop and drains the remaining buffered tickets. +func (writer *LogWriter) gatherBatch(firstTicket *commitTicket, batch []*commitTicket, buffer []byte) ([]*commitTicket, []byte) { + batch = batch[:0] + buffer = buffer[:0] + + batch = append(batch, firstTicket) + buffer = append(buffer, firstTicket.frameData...) + + pendingWrites := len(writer.ingestionChannel) + for i := 0; i < pendingWrites; i++ { + if int64(len(buffer)) >= MaxBatchSizeBytes { + break + } + ticket := <-writer.ingestionChannel + batch = append(batch, ticket) + buffer = append(buffer, ticket.frameData...) + } + + return batch, buffer +} From d4b949d95ef2553f72817a153e23a41249323d86 Mon Sep 17 00:00:00 2001 From: Saptak Manna Date: Mon, 8 Jun 2026 13:41:45 +0530 Subject: [PATCH 13/18] feat: implement Write-Ahead Log recovery with segment replay and corruption handling --- internal/storage/wal/recovery.go | 45 ++++++++++++--------------- internal/storage/wal/recovery_test.go | 28 +++++++++++++---- 2 files changed, 42 insertions(+), 31 deletions(-) diff --git a/internal/storage/wal/recovery.go b/internal/storage/wal/recovery.go index 82f13e0..8302aa3 100644 --- a/internal/storage/wal/recovery.go +++ b/internal/storage/wal/recovery.go @@ -11,17 +11,17 @@ import ( "sort" ) -// MemTable defines the minimal interface for in-memory storage engine components +// RecordConsumer defines the minimal interface for in-memory storage engine components // that can consume replayed WAL records during recovery. -type MemTable interface { +type RecordConsumer interface { Put(key, value []byte) error Delete(key []byte) error } // Replay scans the specified WAL directory, identifies all segment files // matching the *.wal pattern, and replays their logged operations onto the -// target MemTable. It returns the highest segment ID found or 1 if fresh. -func Replay(directory string, engine MemTable) (int, error) { +// target RecordConsumer. It returns the highest segment ID found or 1 if fresh. +func Replay(directory string, recordConsumer RecordConsumer) (int, error) { slog.Debug("starting WAL recovery sequence", "directory", directory) entries, err := os.ReadDir(directory) @@ -58,7 +58,7 @@ func Replay(directory string, engine MemTable) (int, error) { filePath := filepath.Join(directory, fileName) slog.Debug("replaying WAL segment", "segment_id", segmentID, "file", fileName) - if err := replayFile(filePath, engine); err != nil { + if err := replayFile(filePath, recordConsumer); err != nil { return 0, fmt.Errorf("critical failure while replaying segment %s: %w", fileName, err) } } @@ -72,9 +72,9 @@ func Replay(directory string, engine MemTable) (int, error) { } // replayFile opens a single WAL segment, reads it frame-by-frame, validates -// check-sums and frame sizes, and applies the operations onto the MemTable. +// check-sums and frame sizes, and applies the operations onto the RecordConsumer. // If it encounters corruption or a partial write, it truncates the segment. -func replayFile(filePath string, engine MemTable) (err error) { +func replayFile(filePath string, recordConsumer RecordConsumer) (err error) { file, err := os.OpenFile(filePath, os.O_RDWR, 0o644) if err != nil { return fmt.Errorf("unable to open WAL segment for reading: %w", err) @@ -89,6 +89,13 @@ func replayFile(filePath string, engine MemTable) (err error) { var recordsRecovered int headerBuffer := make([]byte, 8) + truncateAndSync := func() error { + if truncErr := file.Truncate(validBytes); truncErr != nil { + return truncErr + } + return file.Sync() + } + for { _, err := io.ReadFull(file, headerBuffer) if err != nil { @@ -103,10 +110,7 @@ func replayFile(filePath string, engine MemTable) (err error) { "file", filepath.Base(filePath), "valid_bytes", validBytes) - if truncErr := file.Truncate(validBytes); truncErr != nil { - return truncErr - } - return file.Sync() + return truncateAndSync() } return fmt.Errorf("unexpected disk error reading frame header: %w", err) } @@ -118,10 +122,7 @@ func replayFile(filePath string, engine MemTable) (err error) { "frame_size", totalFrameSizeBytes, "valid_bytes", validBytes) - if truncErr := file.Truncate(validBytes); truncErr != nil { - return truncErr - } - return file.Sync() + return truncateAndSync() } payloadSizeBytes := totalFrameSizeBytes - 8 @@ -132,10 +133,7 @@ func replayFile(filePath string, engine MemTable) (err error) { "file", filepath.Base(filePath), "valid_bytes", validBytes) - if truncErr := file.Truncate(validBytes); truncErr != nil { - return truncErr - } - return file.Sync() + return truncateAndSync() } fullFrame := make([]byte, 8+payloadSizeBytes) @@ -150,21 +148,18 @@ func replayFile(filePath string, engine MemTable) (err error) { "valid_bytes", validBytes, "error", err) - if truncErr := file.Truncate(validBytes); truncErr != nil { - return truncErr - } - return file.Sync() + return truncateAndSync() } return fmt.Errorf("failed to decode valid frame payload: %w", err) } switch record.Opcode { case OpcodePut: - if putErr := engine.Put(record.Key, record.Value); putErr != nil { + if putErr := recordConsumer.Put(record.Key, record.Value); putErr != nil { return fmt.Errorf("memtable rejected recovered put operation: %w", putErr) } case OpcodeDelete: - if delErr := engine.Delete(record.Key); delErr != nil { + if delErr := recordConsumer.Delete(record.Key); delErr != nil { return fmt.Errorf("memtable rejected recovered delete operation: %w", delErr) } } diff --git a/internal/storage/wal/recovery_test.go b/internal/storage/wal/recovery_test.go index acf5fe8..0a2de2e 100644 --- a/internal/storage/wal/recovery_test.go +++ b/internal/storage/wal/recovery_test.go @@ -258,9 +258,9 @@ func TestReplay_TruncatedHeader_TruncatesFile(t *testing.T) { dir := t.TempDir() path := segmentPath(dir, 1) - writeRecordsToFile(t, path, []*Record{ - {Opcode: OpcodePut, Key: []byte("ok"), Value: []byte("v")}, - }) + good := &Record{Opcode: OpcodePut, Key: []byte("ok"), Value: []byte("v")} + validBytes := good.Marshal() + writeRecordsToFile(t, path, []*Record{good}) f, _ := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, 0644) f.Write([]byte{0xDE, 0xAD, 0xBE, 0xEF, 0xFF}) @@ -273,6 +273,14 @@ func TestReplay_TruncatedHeader_TruncatesFile(t *testing.T) { if _, ok := mem.puts["ok"]; !ok { t.Error("valid record was not applied") } + + info, err := os.Stat(path) + if err != nil { + t.Fatalf("stat: %v", err) + } + if info.Size() != int64(len(validBytes)) { + t.Errorf("file size after truncate = %d, want %d", info.Size(), len(validBytes)) + } } // TestReplay_TruncatedPayload_TruncatesFile checks truncation when the frame payload is truncated. @@ -280,9 +288,9 @@ func TestReplay_TruncatedPayload_TruncatesFile(t *testing.T) { dir := t.TempDir() path := segmentPath(dir, 1) - writeRecordsToFile(t, path, []*Record{ - {Opcode: OpcodePut, Key: []byte("safe"), Value: []byte("data")}, - }) + good := &Record{Opcode: OpcodePut, Key: []byte("safe"), Value: []byte("data")} + validBytes := good.Marshal() + writeRecordsToFile(t, path, []*Record{good}) fakeKey := []byte("payload-cut") fakeSizeBytes := uint32(8 + 3 + len(fakeKey) + 100) @@ -302,6 +310,14 @@ func TestReplay_TruncatedPayload_TruncatesFile(t *testing.T) { if _, ok := mem.puts["safe"]; !ok { t.Error("valid record before truncation point was not applied") } + + info, err := os.Stat(path) + if err != nil { + t.Fatalf("stat: %v", err) + } + if info.Size() != int64(len(validBytes)) { + t.Errorf("file size after truncate = %d, want %d", info.Size(), len(validBytes)) + } } // TestReplay_EmptySegmentFile_NoError checks that an empty WAL segment is handled gracefully. From 655a012b62832f283c6d8fb5ded3099726d56011 Mon Sep 17 00:00:00 2001 From: Saptak Manna Date: Mon, 8 Jun 2026 13:49:52 +0530 Subject: [PATCH 14/18] feat: implement WAL writer with concurrent-safe append and automatic segment rotation --- internal/storage/wal/writer.go | 32 ++-- internal/storage/wal/writer_test.go | 277 ++++++++++++++++++++++++++++ 2 files changed, 293 insertions(+), 16 deletions(-) diff --git a/internal/storage/wal/writer.go b/internal/storage/wal/writer.go index e1b6991..afc2d14 100644 --- a/internal/storage/wal/writer.go +++ b/internal/storage/wal/writer.go @@ -28,7 +28,7 @@ type LogWriter struct { currentSizeBytes int64 ingestionChannel chan *commitTicket - stateMu sync.RWMutex + stateMutex sync.RWMutex isClosed bool workerWaitGroup sync.WaitGroup @@ -96,15 +96,15 @@ func (writer *LogWriter) Append(record *Record) error { resultChan: make(chan error, 1), } - writer.stateMu.RLock() + writer.stateMutex.RLock() if writer.isClosed { - writer.stateMu.RUnlock() + writer.stateMutex.RUnlock() return fmt.Errorf("database engine is currently shutting down, write rejected") } slog.Debug("network thread: dropping record into ingestion channel", "frame_size", len(frame)) writer.ingestionChannel <- ticket - writer.stateMu.RUnlock() + writer.stateMutex.RUnlock() return <-ticket.resultChan } @@ -133,10 +133,10 @@ func (writer *LogWriter) batchWorker() { func (writer *LogWriter) Close() error { var closeErr error writer.closeOnce.Do(func() { - writer.stateMu.Lock() + writer.stateMutex.Lock() writer.isClosed = true close(writer.ingestionChannel) - writer.stateMu.Unlock() + writer.stateMutex.Unlock() writer.workerWaitGroup.Wait() @@ -182,22 +182,22 @@ func (writer *LogWriter) writeAndSyncBatch(batch []*commitTicket, buffer []byte) // gatherBatch pulls tickets from the ingestion channel up to the MaxBatchSizeBytes limit. // It takes the first ticket yielded by the range loop and drains the remaining buffered tickets. -func (writer *LogWriter) gatherBatch(firstTicket *commitTicket, batch []*commitTicket, buffer []byte) ([]*commitTicket, []byte) { - batch = batch[:0] - buffer = buffer[:0] +func (writer *LogWriter) gatherBatch(firstTicket *commitTicket, inBatch []*commitTicket, inBuffer []byte) (outBatch []*commitTicket, outBuffer []byte) { + outBatch = inBatch[:0] + outBuffer = inBuffer[:0] - batch = append(batch, firstTicket) - buffer = append(buffer, firstTicket.frameData...) + outBatch = append(outBatch, firstTicket) + outBuffer = append(outBuffer, firstTicket.frameData...) pendingWrites := len(writer.ingestionChannel) - for i := 0; i < pendingWrites; i++ { - if int64(len(buffer)) >= MaxBatchSizeBytes { + for range pendingWrites { + if int64(len(outBuffer)) >= MaxBatchSizeBytes { break } ticket := <-writer.ingestionChannel - batch = append(batch, ticket) - buffer = append(buffer, ticket.frameData...) + outBatch = append(outBatch, ticket) + outBuffer = append(outBuffer, ticket.frameData...) } - return batch, buffer + return outBatch, outBuffer } diff --git a/internal/storage/wal/writer_test.go b/internal/storage/wal/writer_test.go index 4413c36..3c1678d 100644 --- a/internal/storage/wal/writer_test.go +++ b/internal/storage/wal/writer_test.go @@ -473,3 +473,280 @@ func TestBatchWorker_GroupCommit_AllTicketsSignalled(t *testing.T) { } } } + +// TestMaxBatchSizeBytes_Is4MB checks the batch size limit constant. +func TestMaxBatchSizeBytes_Is4MB(t *testing.T) { + expected := int64(4 * 1024 * 1024) + if MaxBatchSizeBytes != expected { + t.Errorf("MaxBatchSizeBytes = %d, want %d (4 MiB)", MaxBatchSizeBytes, expected) + } +} + +// TestAppend_ConcurrentWithClose_NoHang verifies that goroutines calling Append +// while Close is invoked concurrently do not deadlock or panic. +func TestAppend_ConcurrentWithClose_NoHang(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + + const numWriters = 50 + var wg sync.WaitGroup + + // Launch many concurrent writers. + for i := 0; i < numWriters; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < 20; j++ { + r := &Record{ + Opcode: OpcodePut, + Key: []byte(fmt.Sprintf("race-key-%d-%d", id, j)), + Value: []byte("v"), + } + _ = w.Append(r) // may return shutdown error, that's fine + } + }(i) + } + + // Close concurrently with writers. + done := make(chan struct{}) + go func() { + time.Sleep(1 * time.Millisecond) + w.Close() + close(done) + }() + + wg.Wait() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Close or writers hung for more than 5 seconds") + } +} + +// TestAppend_AfterClose_ReturnsShutdownError checks the specific rejection path +// where Append observes isClosed=true under the read lock. +func TestAppend_AfterClose_ReturnsShutdownError(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + + // Write one record to prove the writer works. + r := &Record{Opcode: OpcodePut, Key: []byte("before"), Value: []byte("close")} + if err := w.Append(r); err != nil { + t.Fatalf("Append before Close: %v", err) + } + + w.Close() + + // Multiple post-close appends should all return errors, never panic. + for i := 0; i < 10; i++ { + r := &Record{Opcode: OpcodePut, Key: []byte(fmt.Sprintf("post-%d", i)), Value: []byte("v")} + if err := w.Append(r); err == nil { + t.Errorf("Append[%d] after Close returned nil, expected error", i) + } + } +} + +// TestClose_ConcurrentCalls_NoPanic verifies that calling Close from multiple +// goroutines simultaneously does not panic or return inconsistent errors. +func TestClose_ConcurrentCalls_NoPanic(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + + // Write some data first. + for i := 0; i < 10; i++ { + r := &Record{Opcode: OpcodePut, Key: []byte(fmt.Sprintf("k%d", i)), Value: []byte("v")} + if err := w.Append(r); err != nil { + t.Fatalf("Append: %v", err) + } + } + + const numClosers = 10 + var wg sync.WaitGroup + for i := 0; i < numClosers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _ = w.Close() + }() + } + + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("concurrent Close calls hung for more than 5 seconds") + } +} + +// TestClose_DrainsInFlightTickets ensures that records already in the ingestion +// channel at the time of Close are still flushed and recoverable. +func TestClose_DrainsInFlightTickets(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + + const numRecords = 50 + var wg sync.WaitGroup + errs := make([]error, numRecords) + + for i := 0; i < numRecords; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + r := &Record{ + Opcode: OpcodePut, + Key: []byte(fmt.Sprintf("drain-%04d", idx)), + Value: []byte("v"), + } + errs[idx] = w.Append(r) + }(i) + } + wg.Wait() + + if err := w.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + + // Count how many succeeded. + var successCount int + for _, e := range errs { + if e == nil { + successCount++ + } + } + if successCount == 0 { + t.Fatal("no records were successfully appended") + } + + // Verify that all successfully appended records are recoverable. + mem := newMockMemTable() + if _, err := Replay(dir, mem); err != nil { + t.Fatalf("Replay: %v", err) + } + + for i, e := range errs { + key := fmt.Sprintf("drain-%04d", i) + if e == nil { + if _, ok := mem.puts[key]; !ok { + t.Errorf("record %q was accepted by Append but not recovered by Replay", key) + } + } + } +} + +// TestBatchWorker_ExitsCleanly_WhenChannelClosed verifies the batchWorker goroutine +// terminates cleanly when the ingestion channel is closed (the new for-range pattern). +func TestBatchWorker_ExitsCleanly_WhenChannelClosed(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + + // Append a record to prove the worker is running. + r := &Record{Opcode: OpcodePut, Key: []byte("alive"), Value: []byte("yes")} + if err := w.Append(r); err != nil { + t.Fatalf("Append: %v", err) + } + + // Close should cause the channel to close and the worker to exit. + done := make(chan struct{}) + go func() { + w.Close() + close(done) + }() + + select { + case <-done: + // Worker exited cleanly. + case <-time.After(3 * time.Second): + t.Fatal("batchWorker did not exit within 3 seconds after channel close") + } + + // Confirm the data is durable. + mem := newMockMemTable() + if _, err := Replay(dir, mem); err != nil { + t.Fatalf("Replay: %v", err) + } + if string(mem.puts["alive"]) != "yes" { + t.Error("record written before Close was not recovered") + } +} + +// TestAppend_ConcurrentWritesDuringClose_AllResolve verifies that every goroutine +// that calls Append gets a definitive result (success or error), even when Close +// is called concurrently. No goroutine should hang. +func TestAppend_ConcurrentWritesDuringClose_AllResolve(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + + const numWriters = 100 + results := make(chan error, numWriters) + + var wg sync.WaitGroup + for i := 0; i < numWriters; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + r := &Record{ + Opcode: OpcodePut, + Key: []byte(fmt.Sprintf("resolve-%d", id)), + Value: []byte("v"), + } + results <- w.Append(r) + }(i) + } + + // Close while writers are still racing. + go func() { + time.Sleep(500 * time.Microsecond) + w.Close() + }() + + // All writers must eventually return. + allDone := make(chan struct{}) + go func() { + wg.Wait() + close(allDone) + }() + + select { + case <-allDone: + case <-time.After(5 * time.Second): + t.Fatal("not all Append goroutines resolved within 5 seconds") + } + + close(results) + var succeeded, failed int + for err := range results { + if err == nil { + succeeded++ + } else { + failed++ + } + } + t.Logf("concurrent close test: %d succeeded, %d rejected", succeeded, failed) + + if succeeded+failed != numWriters { + t.Errorf("expected %d total results, got %d", numWriters, succeeded+failed) + } +} From 3b7d5f9394249277b237f7c2513421901c98348e Mon Sep 17 00:00:00 2001 From: Saptak Manna Date: Wed, 10 Jun 2026 07:23:58 +0530 Subject: [PATCH 15/18] Refactor unit tests for WAL recovery, writer, and record serialization logic --- internal/storage/wal/record_test.go | 46 ++++++++--------- internal/storage/wal/recovery_test.go | 74 +++++++++++++-------------- internal/storage/wal/writer_test.go | 12 ++--- 3 files changed, 66 insertions(+), 66 deletions(-) diff --git a/internal/storage/wal/record_test.go b/internal/storage/wal/record_test.go index 1f79296..a5aebe8 100644 --- a/internal/storage/wal/record_test.go +++ b/internal/storage/wal/record_test.go @@ -19,30 +19,30 @@ func TestMarshal_FrameLayout(t *testing.T) { r := &Record{Opcode: OpcodePut, Key: []byte("hello"), Value: []byte("world")} frame := r.Marshal() - wantSize := 8 + 3 + len(r.Key) + len(r.Value) + wantSize := fixedHeaderSize + len(r.Key) + len(r.Value) if len(frame) != wantSize { t.Fatalf("frame length = %d, want %d", len(frame), wantSize) } - storedSize := binary.LittleEndian.Uint32(frame[4:8]) + storedSize := binary.LittleEndian.Uint32(frame[frameSizeOffset:opcodeOffset]) if int(storedSize) != wantSize { t.Errorf("stored size field = %d, want %d", storedSize, wantSize) } - if frame[8] != OpcodePut { - t.Errorf("opcode byte = %d, want %d", frame[8], OpcodePut) + if frame[opcodeOffset] != OpcodePut { + t.Errorf("opcode byte = %d, want %d", frame[opcodeOffset], OpcodePut) } - storedKeyLen := binary.LittleEndian.Uint16(frame[9:11]) + storedKeyLen := binary.LittleEndian.Uint16(frame[keyLengthOffset:keyOffset]) if int(storedKeyLen) != len(r.Key) { t.Errorf("stored key length = %d, want %d", storedKeyLen, len(r.Key)) } - if !bytes.Equal(frame[11:11+len(r.Key)], r.Key) { + if !bytes.Equal(frame[keyOffset:keyOffset+len(r.Key)], r.Key) { t.Error("key bytes mismatch in frame") } - if !bytes.Equal(frame[11+len(r.Key):], r.Value) { + if !bytes.Equal(frame[keyOffset+len(r.Key):], r.Value) { t.Error("value bytes mismatch in frame") } } @@ -53,9 +53,9 @@ func TestMarshal_CRCCoversPayload(t *testing.T) { r := &Record{Opcode: OpcodeDelete, Key: []byte("k"), Value: nil} frame := r.Marshal() - storedCRC := binary.LittleEndian.Uint32(frame[0:4]) - if storedCRC != crc32.ChecksumIEEE(frame[4:]) { - t.Errorf("CRC mismatch: stored=%d calculated=%d", storedCRC, crc32.ChecksumIEEE(frame[4:])) + storedCRC := binary.LittleEndian.Uint32(frame[checksumOffset:frameSizeOffset]) + if storedCRC != crc32.ChecksumIEEE(frame[frameSizeOffset:]) { + t.Errorf("CRC mismatch: stored=%d calculated=%d", storedCRC, crc32.ChecksumIEEE(frame[frameSizeOffset:])) } } @@ -63,8 +63,8 @@ func TestMarshal_CRCCoversPayload(t *testing.T) { func TestMarshal_OpcodeDelete(t *testing.T) { r := &Record{Opcode: OpcodeDelete, Key: []byte("mykey"), Value: nil} frame := r.Marshal() - if frame[8] != OpcodeDelete { - t.Errorf("opcode = %d, want OpcodeDelete (%d)", frame[8], OpcodeDelete) + if frame[opcodeOffset] != OpcodeDelete { + t.Errorf("opcode = %d, want OpcodeDelete (%d)", frame[opcodeOffset], OpcodeDelete) } } @@ -80,7 +80,7 @@ func TestMarshal_ZeroLengthKey(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { frame := (&Record{Opcode: OpcodePut, Key: tc.key, Value: []byte("v")}).Marshal() - if keyLen := binary.LittleEndian.Uint16(frame[9:11]); keyLen != 0 { + if keyLen := binary.LittleEndian.Uint16(frame[keyLengthOffset:keyOffset]); keyLen != 0 { t.Errorf("key length = %d, want 0", keyLen) } }) @@ -100,7 +100,7 @@ func TestMarshal_ZeroLengthValue(t *testing.T) { t.Run(tc.name, func(t *testing.T) { r := &Record{Opcode: OpcodePut, Key: []byte("key"), Value: tc.value} frame := r.Marshal() - wantSize := 8 + 3 + len(r.Key) + wantSize := fixedHeaderSize + len(r.Key) if len(frame) != wantSize { t.Errorf("frame length = %d, want %d", len(frame), wantSize) } @@ -114,12 +114,12 @@ func TestMarshal_LargePayload(t *testing.T) { value := bytes.Repeat([]byte("v"), 4096) r := &Record{Opcode: OpcodePut, Key: key, Value: value} frame := r.Marshal() - wantSize := 8 + 3 + len(key) + len(value) + wantSize := fixedHeaderSize + len(key) + len(value) if len(frame) != wantSize { t.Errorf("frame size = %d, want %d", len(frame), wantSize) } - storedCRC := binary.LittleEndian.Uint32(frame[0:4]) - if storedCRC != crc32.ChecksumIEEE(frame[4:]) { + storedCRC := binary.LittleEndian.Uint32(frame[checksumOffset:frameSizeOffset]) + if storedCRC != crc32.ChecksumIEEE(frame[frameSizeOffset:]) { t.Error("CRC invalid for large payload") } } @@ -223,9 +223,9 @@ func TestUnmarshal_CRCCorruption(t *testing.T) { name string corruptFn func([]byte) }{ - {"payload byte", func(f []byte) { f[12] ^= 0x01 }}, - {"crc field", func(f []byte) { f[0] ^= 0xFF }}, - {"size field", func(f []byte) { binary.LittleEndian.PutUint32(f[4:8], 9999) }}, + {"payload byte", func(f []byte) { f[keyOffset+1] ^= 0x01 }}, + {"crc field", func(f []byte) { f[checksumOffset] ^= 0xFF }}, + {"size field", func(f []byte) { binary.LittleEndian.PutUint32(f[frameSizeOffset:opcodeOffset], 9999) }}, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { @@ -245,9 +245,9 @@ func TestUnmarshal_KeyLengthExceedsFrame(t *testing.T) { r := &Record{Opcode: OpcodePut, Key: []byte("ab"), Value: []byte("v")} frame := r.Marshal() - binary.LittleEndian.PutUint16(frame[9:11], 65535) - newCRC := crc32.ChecksumIEEE(frame[4:]) - binary.LittleEndian.PutUint32(frame[0:4], newCRC) + binary.LittleEndian.PutUint16(frame[keyLengthOffset:keyOffset], 65535) + newCRC := crc32.ChecksumIEEE(frame[frameSizeOffset:]) + binary.LittleEndian.PutUint32(frame[checksumOffset:frameSizeOffset], newCRC) _, err := UnmarshalRecord(frame) if !errors.Is(err, ErrTruncated) { diff --git a/internal/storage/wal/recovery_test.go b/internal/storage/wal/recovery_test.go index 0a2de2e..14c9e46 100644 --- a/internal/storage/wal/recovery_test.go +++ b/internal/storage/wal/recovery_test.go @@ -9,22 +9,22 @@ import ( "testing" ) -// mockMemTable implements the MemTable interface for tracking replayed actions +// mockRecordConsumer implements the RecordConsumer interface for tracking replayed actions // and simulating write failures during recovery tests. -type mockMemTable struct { +type mockRecordConsumer struct { puts map[string][]byte deletes []string putErr error delErr error } -// newMockMemTable constructs an empty mockMemTable. -func newMockMemTable() *mockMemTable { - return &mockMemTable{puts: make(map[string][]byte)} +// newMockRecordConsumer constructs an empty mockRecordConsumer. +func newMockRecordConsumer() *mockRecordConsumer { + return &mockRecordConsumer{puts: make(map[string][]byte)} } // Put records a put operation or returns an injected failure error. -func (m *mockMemTable) Put(key, value []byte) error { +func (m *mockRecordConsumer) Put(key, value []byte) error { if m.putErr != nil { return m.putErr } @@ -33,7 +33,7 @@ func (m *mockMemTable) Put(key, value []byte) error { } // Delete records a delete operation or returns an injected failure error. -func (m *mockMemTable) Delete(key []byte) error { +func (m *mockRecordConsumer) Delete(key []byte) error { if m.delErr != nil { return m.delErr } @@ -65,7 +65,7 @@ func segmentPath(dir string, id int) string { // correctly if the WAL directory does not exist. func TestReplay_NonExistentDirectory_ReturnsFreshSegmentID(t *testing.T) { dir := filepath.Join(t.TempDir(), "no-such-wal") - mem := newMockMemTable() + mem := newMockRecordConsumer() nextID, err := Replay(dir, mem) if err != nil { @@ -82,7 +82,7 @@ func TestReplay_NonExistentDirectory_ReturnsFreshSegmentID(t *testing.T) { // TestReplay_EmptyDirectory_ReturnsFreshSegmentID checks recovery on an empty directory. func TestReplay_EmptyDirectory_ReturnsFreshSegmentID(t *testing.T) { dir := t.TempDir() - mem := newMockMemTable() + mem := newMockRecordConsumer() nextID, err := Replay(dir, mem) if err != nil { @@ -101,7 +101,7 @@ func TestReplay_NonWALFilesIgnored(t *testing.T) { t.Fatal(err) } } - mem := newMockMemTable() + mem := newMockRecordConsumer() nextID, err := Replay(dir, mem) if err != nil { t.Fatalf("expected no error, got %v", err) @@ -121,7 +121,7 @@ func TestReplay_SingleSegment_AllPuts(t *testing.T) { } writeRecordsToFile(t, segmentPath(dir, 1), records) - mem := newMockMemTable() + mem := newMockRecordConsumer() nextID, err := Replay(dir, mem) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -144,7 +144,7 @@ func TestReplay_SingleSegment_Deletes(t *testing.T) { {Opcode: OpcodeDelete, Key: []byte("x")}, }) - mem := newMockMemTable() + mem := newMockRecordConsumer() if _, err := Replay(dir, mem); err != nil { t.Fatalf("unexpected error: %v", err) } @@ -167,7 +167,7 @@ func TestReplay_MultipleSegments_ReplayedInOrder(t *testing.T) { {Opcode: OpcodePut, Key: []byte("k"), Value: []byte("seg2")}, }) - mem := newMockMemTable() + mem := newMockRecordConsumer() nextID, err := Replay(dir, mem) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -188,7 +188,7 @@ func TestReplay_ReturnsHighestSegmentID(t *testing.T) { {Opcode: OpcodePut, Key: []byte(fmt.Sprintf("k%d", id)), Value: []byte("v")}, }) } - mem := newMockMemTable() + mem := newMockRecordConsumer() nextID, err := Replay(dir, mem) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -205,7 +205,7 @@ func TestReplay_UnknownOpcode_Ignored(t *testing.T) { {Opcode: 99, Key: []byte("k"), Value: []byte("v")}, }) - mem := newMockMemTable() + mem := newMockRecordConsumer() if _, err := Replay(dir, mem); err != nil { t.Errorf("unexpected error for unknown opcode: %v", err) } @@ -232,7 +232,7 @@ func TestReplay_CorruptedCRC_TruncatesFile(t *testing.T) { f.Write(badFrame) f.Close() - mem := newMockMemTable() + mem := newMockRecordConsumer() if _, err := Replay(dir, mem); err != nil { t.Fatalf("unexpected error: %v", err) } @@ -266,7 +266,7 @@ func TestReplay_TruncatedHeader_TruncatesFile(t *testing.T) { f.Write([]byte{0xDE, 0xAD, 0xBE, 0xEF, 0xFF}) f.Close() - mem := newMockMemTable() + mem := newMockRecordConsumer() if _, err := Replay(dir, mem); err != nil { t.Fatalf("unexpected error: %v", err) } @@ -293,17 +293,17 @@ func TestReplay_TruncatedPayload_TruncatesFile(t *testing.T) { writeRecordsToFile(t, path, []*Record{good}) fakeKey := []byte("payload-cut") - fakeSizeBytes := uint32(8 + 3 + len(fakeKey) + 100) - hdr := make([]byte, 8) - binary.LittleEndian.PutUint32(hdr[4:8], fakeSizeBytes) - binary.LittleEndian.PutUint32(hdr[0:4], crc32.ChecksumIEEE(hdr[4:])) + fakeSizeBytes := uint32(fixedHeaderSize + len(fakeKey) + 100) + hdr := make([]byte, checksumSize+frameSizeSize) + binary.LittleEndian.PutUint32(hdr[frameSizeOffset:frameSizeOffset+frameSizeSize], fakeSizeBytes) + binary.LittleEndian.PutUint32(hdr[checksumOffset:checksumOffset+checksumSize], crc32.ChecksumIEEE(hdr[frameSizeOffset:])) f, _ := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, 0644) f.Write(hdr) f.Write(fakeKey[:3]) f.Close() - mem := newMockMemTable() + mem := newMockRecordConsumer() if _, err := Replay(dir, mem); err != nil { t.Fatalf("unexpected error: %v", err) } @@ -325,7 +325,7 @@ func TestReplay_EmptySegmentFile_NoError(t *testing.T) { dir := t.TempDir() os.WriteFile(segmentPath(dir, 1), []byte{}, 0644) - mem := newMockMemTable() + mem := newMockRecordConsumer() nextID, err := Replay(dir, mem) if err != nil { t.Fatalf("unexpected error for empty WAL file: %v", err) @@ -342,7 +342,7 @@ func TestReplay_MemTablePutError_PropagatesError(t *testing.T) { {Opcode: OpcodePut, Key: []byte("k"), Value: []byte("v")}, }) - mem := newMockMemTable() + mem := newMockRecordConsumer() mem.putErr = fmt.Errorf("memtable full") if _, err := Replay(dir, mem); err == nil { @@ -357,7 +357,7 @@ func TestReplay_MemTableDeleteError_PropagatesError(t *testing.T) { {Opcode: OpcodeDelete, Key: []byte("k")}, }) - mem := newMockMemTable() + mem := newMockRecordConsumer() mem.delErr = fmt.Errorf("read-only memtable") if _, err := Replay(dir, mem); err == nil { @@ -368,7 +368,7 @@ func TestReplay_MemTableDeleteError_PropagatesError(t *testing.T) { // TestReplay_MixedPutsAndDeletes_CorrectOrder verifies interleaved Puts/Deletes replay in correct order. func TestReplay_MixedPutsAndDeletes_CorrectOrder(t *testing.T) { dir := t.TempDir() - mem := newMockMemTable() + mem := newMockRecordConsumer() writeRecordsToFile(t, segmentPath(dir, 1), []*Record{ {Opcode: OpcodePut, Key: []byte("a"), Value: []byte("1")}, @@ -409,7 +409,7 @@ func TestReplay_SegmentsSortedNumerically(t *testing.T) { }) } - mem := newMockMemTable() + mem := newMockRecordConsumer() nextID, err := Replay(dir, mem) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -433,7 +433,7 @@ func TestReplay_SubdirectoriesAreIgnored(t *testing.T) { {Opcode: OpcodePut, Key: []byte("k"), Value: []byte("v")}, }) - mem := newMockMemTable() + mem := newMockRecordConsumer() nextID, err := Replay(dir, mem) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -452,9 +452,9 @@ func TestReplay_InvalidFrameSize_TooSmall_TruncatesFile(t *testing.T) { validBytes := good.Marshal() writeRecordsToFile(t, path, []*Record{good}) - hdr := make([]byte, 8) - binary.LittleEndian.PutUint32(hdr[4:8], 7) - binary.LittleEndian.PutUint32(hdr[0:4], crc32.ChecksumIEEE(hdr[4:])) + hdr := make([]byte, checksumSize+frameSizeSize) + binary.LittleEndian.PutUint32(hdr[frameSizeOffset:frameSizeOffset+frameSizeSize], 7) + binary.LittleEndian.PutUint32(hdr[checksumOffset:checksumOffset+checksumSize], crc32.ChecksumIEEE(hdr[frameSizeOffset:])) f, err := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, 0644) if err != nil { @@ -463,7 +463,7 @@ func TestReplay_InvalidFrameSize_TooSmall_TruncatesFile(t *testing.T) { f.Write(hdr) f.Close() - mem := newMockMemTable() + mem := newMockRecordConsumer() if _, err := Replay(dir, mem); err != nil { t.Fatalf("unexpected error: %v", err) } @@ -489,15 +489,15 @@ func TestReplay_InvalidFrameSize_TooLarge_TruncatesFile(t *testing.T) { validBytes := good.Marshal() writeRecordsToFile(t, path, []*Record{good}) - hdr := make([]byte, 8) - binary.LittleEndian.PutUint32(hdr[4:8], 129*1024*1024) - binary.LittleEndian.PutUint32(hdr[0:4], crc32.ChecksumIEEE(hdr[4:])) + hdr := make([]byte, checksumSize+frameSizeSize) + binary.LittleEndian.PutUint32(hdr[frameSizeOffset:frameSizeOffset+frameSizeSize], 129*1024*1024) + binary.LittleEndian.PutUint32(hdr[checksumOffset:checksumOffset+checksumSize], crc32.ChecksumIEEE(hdr[frameSizeOffset:])) f, _ := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, 0644) f.Write(hdr) f.Close() - mem := newMockMemTable() + mem := newMockRecordConsumer() if _, err := Replay(dir, mem); err != nil { t.Fatalf("unexpected error: %v", err) } @@ -531,7 +531,7 @@ func TestReplay_MalformedWALFilename_Skipped(t *testing.T) { f.Write(bad.Marshal()) f.Close() - mem := newMockMemTable() + mem := newMockRecordConsumer() nextID, err := Replay(dir, mem) if err != nil { t.Fatalf("unexpected error replaying: %v", err) diff --git a/internal/storage/wal/writer_test.go b/internal/storage/wal/writer_test.go index 3c1678d..34d1f08 100644 --- a/internal/storage/wal/writer_test.go +++ b/internal/storage/wal/writer_test.go @@ -114,7 +114,7 @@ func TestAppend_MultipleRecords_AllWritten(t *testing.T) { t.Fatalf("Close: %v", err) } - mem := newMockMemTable() + mem := newMockRecordConsumer() if _, err := Replay(dir, mem); err != nil { t.Fatalf("Replay: %v", err) } @@ -146,7 +146,7 @@ func TestAppend_RecordRoundtrip_ViaReplay(t *testing.T) { } w.Close() - mem := newMockMemTable() + mem := newMockRecordConsumer() if _, err := Replay(dir, mem); err != nil { t.Fatalf("Replay: %v", err) } @@ -360,7 +360,7 @@ func TestAppend_ConcurrentWrites_AllRecordsRecoverable(t *testing.T) { wg.Wait() w.Close() - mem := newMockMemTable() + mem := newMockRecordConsumer() if _, err := Replay(dir, mem); err != nil { t.Fatalf("Replay: %v", err) } @@ -424,7 +424,7 @@ func TestClose_SyncsDataToDisk(t *testing.T) { t.Fatalf("Close: %v", err) } - mem := newMockMemTable() + mem := newMockRecordConsumer() if _, err := Replay(dir, mem); err != nil { t.Fatalf("Replay: %v", err) } @@ -635,7 +635,7 @@ func TestClose_DrainsInFlightTickets(t *testing.T) { } // Verify that all successfully appended records are recoverable. - mem := newMockMemTable() + mem := newMockRecordConsumer() if _, err := Replay(dir, mem); err != nil { t.Fatalf("Replay: %v", err) } @@ -680,7 +680,7 @@ func TestBatchWorker_ExitsCleanly_WhenChannelClosed(t *testing.T) { } // Confirm the data is durable. - mem := newMockMemTable() + mem := newMockRecordConsumer() if _, err := Replay(dir, mem); err != nil { t.Fatalf("Replay: %v", err) } From acaee83eeb41b17282a1a74535a007a9a116ed13 Mon Sep 17 00:00:00 2001 From: Saptak Manna Date: Wed, 10 Jun 2026 08:13:11 +0530 Subject: [PATCH 16/18] Add ErrKeyTooLarge and ErrFrameTooLarge --- internal/storage/wal/errors.go | 30 +++++++++++- internal/storage/wal/record.go | 17 ++++++- internal/storage/wal/record_test.go | 67 +++++++++++++++++++++------ internal/storage/wal/recovery_test.go | 16 +++---- internal/storage/wal/writer.go | 5 +- 5 files changed, 110 insertions(+), 25 deletions(-) diff --git a/internal/storage/wal/errors.go b/internal/storage/wal/errors.go index b19c51b..f2276af 100644 --- a/internal/storage/wal/errors.go +++ b/internal/storage/wal/errors.go @@ -2,7 +2,10 @@ // It supports log appending, rotation, record serialization, and replay recovery. package wal -import "errors" +import ( + "errors" + "math" +) var ( // ErrInvalidCRC is returned when a record's checksum does not match its payload. @@ -13,4 +16,29 @@ var ( // ErrEmptyKey is returned when attempting to write a log record with a zero-length key. ErrEmptyKey = errors.New("wal record rejected: key must not be empty") + + // ErrKeyTooLarge is returned when the key exceeds the maximum representable + // length (math.MaxUint16 bytes) in the on-disk frame format. + ErrKeyTooLarge = errors.New("wal record rejected: key length exceeds maximum of " + uitoa(math.MaxUint16) + " bytes") + + // ErrFrameTooLarge is returned when the total serialized frame size exceeds + // the maximum representable size (math.MaxUint32 bytes) in the on-disk format. + ErrFrameTooLarge = errors.New("wal record rejected: frame size exceeds maximum of " + uitoa(math.MaxUint32) + " bytes") ) + +// uitoa converts an unsigned integer to its decimal string representation. +// Used to embed numeric limits in error sentinel messages at init time without +// depending on strconv or fmt. +func uitoa(val uint64) string { + if val == 0 { + return "0" + } + var buf [20]byte // max uint64 is 20 digits + i := len(buf) + for val > 0 { + i-- + buf[i] = byte(val%10) + '0' + val /= 10 + } + return string(buf[i:]) +} diff --git a/internal/storage/wal/record.go b/internal/storage/wal/record.go index c49d2e1..00b34fb 100644 --- a/internal/storage/wal/record.go +++ b/internal/storage/wal/record.go @@ -3,6 +3,7 @@ package wal import ( "encoding/binary" "hash/crc32" + "math" ) const ( @@ -49,12 +50,24 @@ const ( // +-------------+-------------+----------+------------+-----------+-----------+ // // Note: The Checksum (CRC32) covers all bytes starting from the Frame Size. -func (record *Record) Marshal() []byte { +// +// Marshal returns ErrKeyTooLarge if the key exceeds math.MaxUint16 bytes, or +// ErrFrameTooLarge if the total frame size exceeds math.MaxUint32 bytes, since +// these sizes cannot be represented in the on-disk field widths. +func (record *Record) Marshal() ([]byte, error) { keyLen := len(record.Key) valLen := len(record.Value) + if keyLen > math.MaxUint16 { + return nil, ErrKeyTooLarge + } + totalFrameSizeBytes := fixedHeaderSize + keyLen + valLen + if totalFrameSizeBytes > math.MaxUint32 { + return nil, ErrFrameTooLarge + } + frameBuffer := make([]byte, totalFrameSizeBytes) binary.LittleEndian.PutUint32(frameBuffer[frameSizeOffset:opcodeOffset], uint32(totalFrameSizeBytes)) @@ -68,7 +81,7 @@ func (record *Record) Marshal() []byte { calculatedChecksum := crc32.ChecksumIEEE(frameBuffer[frameSizeOffset:]) binary.LittleEndian.PutUint32(frameBuffer[checksumOffset:frameSizeOffset], calculatedChecksum) - return frameBuffer + return frameBuffer, nil } // UnmarshalRecord deserializes a raw binary frame and reconstructs the original diff --git a/internal/storage/wal/record_test.go b/internal/storage/wal/record_test.go index a5aebe8..9e72300 100644 --- a/internal/storage/wal/record_test.go +++ b/internal/storage/wal/record_test.go @@ -5,19 +5,34 @@ import ( "encoding/binary" "errors" "hash/crc32" + "math" "testing" ) +// mustMarshal is a test helper that calls Marshal and fails the test on error. +func mustMarshal(t *testing.T, r *Record) []byte { + t.Helper() + frame, err := r.Marshal() + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + return frame +} + // buildValidFrame is a helper that constructs a valid binary frame from a Record. func buildValidFrame(r *Record) []byte { - return r.Marshal() + frame, err := r.Marshal() + if err != nil { + panic("buildValidFrame: " + err.Error()) + } + return frame } // TestMarshal_FrameLayout tests that the layout of the marshaled frame meets // specification requirements (proper size, opcode location, and key/value placement). func TestMarshal_FrameLayout(t *testing.T) { r := &Record{Opcode: OpcodePut, Key: []byte("hello"), Value: []byte("world")} - frame := r.Marshal() + frame := mustMarshal(t, r) wantSize := fixedHeaderSize + len(r.Key) + len(r.Value) if len(frame) != wantSize { @@ -51,7 +66,7 @@ func TestMarshal_FrameLayout(t *testing.T) { // correctly over the rest of the frame payload. func TestMarshal_CRCCoversPayload(t *testing.T) { r := &Record{Opcode: OpcodeDelete, Key: []byte("k"), Value: nil} - frame := r.Marshal() + frame := mustMarshal(t, r) storedCRC := binary.LittleEndian.Uint32(frame[checksumOffset:frameSizeOffset]) if storedCRC != crc32.ChecksumIEEE(frame[frameSizeOffset:]) { @@ -62,7 +77,7 @@ func TestMarshal_CRCCoversPayload(t *testing.T) { // TestMarshal_OpcodeDelete verifies that deletion records marshal with OpcodeDelete. func TestMarshal_OpcodeDelete(t *testing.T) { r := &Record{Opcode: OpcodeDelete, Key: []byte("mykey"), Value: nil} - frame := r.Marshal() + frame := mustMarshal(t, r) if frame[opcodeOffset] != OpcodeDelete { t.Errorf("opcode = %d, want OpcodeDelete (%d)", frame[opcodeOffset], OpcodeDelete) } @@ -79,7 +94,7 @@ func TestMarshal_ZeroLengthKey(t *testing.T) { } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - frame := (&Record{Opcode: OpcodePut, Key: tc.key, Value: []byte("v")}).Marshal() + frame := mustMarshal(t, &Record{Opcode: OpcodePut, Key: tc.key, Value: []byte("v")}) if keyLen := binary.LittleEndian.Uint16(frame[keyLengthOffset:keyOffset]); keyLen != 0 { t.Errorf("key length = %d, want 0", keyLen) } @@ -99,7 +114,7 @@ func TestMarshal_ZeroLengthValue(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { r := &Record{Opcode: OpcodePut, Key: []byte("key"), Value: tc.value} - frame := r.Marshal() + frame := mustMarshal(t, r) wantSize := fixedHeaderSize + len(r.Key) if len(frame) != wantSize { t.Errorf("frame length = %d, want %d", len(frame), wantSize) @@ -113,7 +128,7 @@ func TestMarshal_LargePayload(t *testing.T) { key := bytes.Repeat([]byte("k"), 1024) value := bytes.Repeat([]byte("v"), 4096) r := &Record{Opcode: OpcodePut, Key: key, Value: value} - frame := r.Marshal() + frame := mustMarshal(t, r) wantSize := fixedHeaderSize + len(key) + len(value) if len(frame) != wantSize { t.Errorf("frame size = %d, want %d", len(frame), wantSize) @@ -130,7 +145,7 @@ func TestMarshal_BinaryKeyAndValue(t *testing.T) { key := []byte{0x00, 0xFF, 0x7F, 0x80} value := []byte{0x01, 0x02, 0x03} r := &Record{Opcode: OpcodePut, Key: key, Value: value} - frame := r.Marshal() + frame := mustMarshal(t, r) recovered, err := UnmarshalRecord(frame) if err != nil { @@ -147,11 +162,37 @@ func TestMarshal_BinaryKeyAndValue(t *testing.T) { // TestMarshal_Idempotent verifies that Marshal output is identical for successive calls. func TestMarshal_Idempotent(t *testing.T) { r := &Record{Opcode: OpcodePut, Key: []byte("idempotent"), Value: []byte("yes")} - if !bytes.Equal(r.Marshal(), r.Marshal()) { + frame1 := mustMarshal(t, r) + frame2 := mustMarshal(t, r) + if !bytes.Equal(frame1, frame2) { t.Error("Marshal is not idempotent for the same record") } } +// TestMarshal_KeyTooLarge verifies that keys exceeding math.MaxUint16 bytes are rejected. +func TestMarshal_KeyTooLarge(t *testing.T) { + oversizedKey := make([]byte, math.MaxUint16+1) + r := &Record{Opcode: OpcodePut, Key: oversizedKey, Value: []byte("v")} + _, err := r.Marshal() + if !errors.Is(err, ErrKeyTooLarge) { + t.Errorf("expected ErrKeyTooLarge, got %v", err) + } +} + +// TestMarshal_MaxKeyLength verifies that a key of exactly math.MaxUint16 bytes is accepted. +func TestMarshal_MaxKeyLength(t *testing.T) { + maxKey := make([]byte, math.MaxUint16) + r := &Record{Opcode: OpcodePut, Key: maxKey, Value: nil} + frame, err := r.Marshal() + if err != nil { + t.Fatalf("Marshal rejected max-length key: %v", err) + } + storedKeyLen := binary.LittleEndian.Uint16(frame[keyLengthOffset:keyOffset]) + if storedKeyLen != math.MaxUint16 { + t.Errorf("stored key length = %d, want %d", storedKeyLen, math.MaxUint16) + } +} + // TestUnmarshal_RoundTrip_Put checks that a serialized Put record can be accurately reconstructed. func TestUnmarshal_RoundTrip_Put(t *testing.T) { original := &Record{Opcode: OpcodePut, Key: []byte("name"), Value: []byte("penguin")} @@ -207,7 +248,7 @@ func TestUnmarshal_TruncatedFrame_TooShort(t *testing.T) { // TestUnmarshal_MinimalFrame_EmptyKeyNoValue verifies parsing of minimally sized frames. func TestUnmarshal_MinimalFrame_EmptyKeyNoValue(t *testing.T) { r := &Record{Opcode: OpcodePut, Key: []byte{}, Value: nil} - frame := r.Marshal() + frame := mustMarshal(t, r) recovered, err := UnmarshalRecord(frame) if err != nil { t.Fatalf("unexpected error for minimal frame: %v", err) @@ -230,7 +271,7 @@ func TestUnmarshal_CRCCorruption(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { r := &Record{Opcode: OpcodePut, Key: []byte("key"), Value: []byte("val")} - frame := r.Marshal() + frame := mustMarshal(t, r) tc.corruptFn(frame) _, err := UnmarshalRecord(frame) if !errors.Is(err, ErrInvalidCRC) { @@ -243,7 +284,7 @@ func TestUnmarshal_CRCCorruption(t *testing.T) { // TestUnmarshal_KeyLengthExceedsFrame checks that key lengths exceeding frame limits are rejected. func TestUnmarshal_KeyLengthExceedsFrame(t *testing.T) { r := &Record{Opcode: OpcodePut, Key: []byte("ab"), Value: []byte("v")} - frame := r.Marshal() + frame := mustMarshal(t, r) binary.LittleEndian.PutUint16(frame[keyLengthOffset:keyOffset], 65535) newCRC := crc32.ChecksumIEEE(frame[frameSizeOffset:]) @@ -258,7 +299,7 @@ func TestUnmarshal_KeyLengthExceedsFrame(t *testing.T) { // TestUnmarshal_NilValueForDeleteRecord checks that deleted records correctly parse nil values. func TestUnmarshal_NilValueForDeleteRecord(t *testing.T) { r := &Record{Opcode: OpcodeDelete, Key: []byte("mykey"), Value: nil} - frame := r.Marshal() + frame := mustMarshal(t, r) recovered, err := UnmarshalRecord(frame) if err != nil { diff --git a/internal/storage/wal/recovery_test.go b/internal/storage/wal/recovery_test.go index 14c9e46..acfbfc4 100644 --- a/internal/storage/wal/recovery_test.go +++ b/internal/storage/wal/recovery_test.go @@ -50,7 +50,7 @@ func writeRecordsToFile(t *testing.T, path string, records []*Record) { } defer f.Close() for _, r := range records { - if _, err := f.Write(r.Marshal()); err != nil { + if _, err := f.Write(mustMarshal(t, r)); err != nil { t.Fatalf("writeRecordsToFile write: %v", err) } } @@ -220,10 +220,10 @@ func TestReplay_CorruptedCRC_TruncatesFile(t *testing.T) { path := segmentPath(dir, 1) good := &Record{Opcode: OpcodePut, Key: []byte("good"), Value: []byte("val")} - validBytes := good.Marshal() + validBytes := mustMarshal(t, good) writeRecordsToFile(t, path, []*Record{good}) - badFrame := (&Record{Opcode: OpcodePut, Key: []byte("bad"), Value: []byte("x")}).Marshal() + badFrame := mustMarshal(t, &Record{Opcode: OpcodePut, Key: []byte("bad"), Value: []byte("x")}) badFrame[12] ^= 0xFF f, err := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, 0644) if err != nil { @@ -259,7 +259,7 @@ func TestReplay_TruncatedHeader_TruncatesFile(t *testing.T) { path := segmentPath(dir, 1) good := &Record{Opcode: OpcodePut, Key: []byte("ok"), Value: []byte("v")} - validBytes := good.Marshal() + validBytes := mustMarshal(t, good) writeRecordsToFile(t, path, []*Record{good}) f, _ := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, 0644) @@ -289,7 +289,7 @@ func TestReplay_TruncatedPayload_TruncatesFile(t *testing.T) { path := segmentPath(dir, 1) good := &Record{Opcode: OpcodePut, Key: []byte("safe"), Value: []byte("data")} - validBytes := good.Marshal() + validBytes := mustMarshal(t, good) writeRecordsToFile(t, path, []*Record{good}) fakeKey := []byte("payload-cut") @@ -449,7 +449,7 @@ func TestReplay_InvalidFrameSize_TooSmall_TruncatesFile(t *testing.T) { path := segmentPath(dir, 1) good := &Record{Opcode: OpcodePut, Key: []byte("good"), Value: []byte("val")} - validBytes := good.Marshal() + validBytes := mustMarshal(t, good) writeRecordsToFile(t, path, []*Record{good}) hdr := make([]byte, checksumSize+frameSizeSize) @@ -486,7 +486,7 @@ func TestReplay_InvalidFrameSize_TooLarge_TruncatesFile(t *testing.T) { path := segmentPath(dir, 1) good := &Record{Opcode: OpcodePut, Key: []byte("good"), Value: []byte("val")} - validBytes := good.Marshal() + validBytes := mustMarshal(t, good) writeRecordsToFile(t, path, []*Record{good}) hdr := make([]byte, checksumSize+frameSizeSize) @@ -528,7 +528,7 @@ func TestReplay_MalformedWALFilename_Skipped(t *testing.T) { if err != nil { t.Fatal(err) } - f.Write(bad.Marshal()) + f.Write(mustMarshal(t, bad)) f.Close() mem := newMockRecordConsumer() diff --git a/internal/storage/wal/writer.go b/internal/storage/wal/writer.go index afc2d14..319e70f 100644 --- a/internal/storage/wal/writer.go +++ b/internal/storage/wal/writer.go @@ -89,7 +89,10 @@ func (writer *LogWriter) Append(record *Record) error { return ErrEmptyKey } - frame := record.Marshal() + frame, err := record.Marshal() + if err != nil { + return err + } ticket := &commitTicket{ frameData: frame, From 99413c31aa525fefd80243bf5334afe5adf06712 Mon Sep 17 00:00:00 2001 From: Saptak Manna Date: Wed, 10 Jun 2026 08:59:58 +0530 Subject: [PATCH 17/18] Added ErrInvalidOpcode sentinel error --- internal/storage/wal/errors.go | 5 +++++ internal/storage/wal/writer.go | 4 ++++ internal/storage/wal/writer_test.go | 18 ++++++++++++++++++ 3 files changed, 27 insertions(+) diff --git a/internal/storage/wal/errors.go b/internal/storage/wal/errors.go index f2276af..eb71327 100644 --- a/internal/storage/wal/errors.go +++ b/internal/storage/wal/errors.go @@ -17,6 +17,11 @@ var ( // ErrEmptyKey is returned when attempting to write a log record with a zero-length key. ErrEmptyKey = errors.New("wal record rejected: key must not be empty") + // ErrInvalidOpcode is returned when a record carries an opcode that is not + // recognized by the WAL format. Persisting such a record would succeed but + // the entry would be silently skipped during recovery replay. + ErrInvalidOpcode = errors.New("wal record rejected: unrecognized opcode") + // ErrKeyTooLarge is returned when the key exceeds the maximum representable // length (math.MaxUint16 bytes) in the on-disk frame format. ErrKeyTooLarge = errors.New("wal record rejected: key length exceeds maximum of " + uitoa(math.MaxUint16) + " bytes") diff --git a/internal/storage/wal/writer.go b/internal/storage/wal/writer.go index 319e70f..7c4ce5e 100644 --- a/internal/storage/wal/writer.go +++ b/internal/storage/wal/writer.go @@ -89,6 +89,10 @@ func (writer *LogWriter) Append(record *Record) error { return ErrEmptyKey } + if record.Opcode != OpcodePut && record.Opcode != OpcodeDelete { + return ErrInvalidOpcode + } + frame, err := record.Marshal() if err != nil { return err diff --git a/internal/storage/wal/writer_test.go b/internal/storage/wal/writer_test.go index 34d1f08..995251e 100644 --- a/internal/storage/wal/writer_test.go +++ b/internal/storage/wal/writer_test.go @@ -215,6 +215,24 @@ func TestAppend_EmptyValue_Allowed(t *testing.T) { } } +// TestAppend_InvalidOpcode_Rejected verifies that records with unrecognized opcodes +// are rejected with ErrInvalidOpcode instead of being silently persisted. +func TestAppend_InvalidOpcode_Rejected(t *testing.T) { + dir := t.TempDir() + w, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter: %v", err) + } + defer w.Close() + + for _, opcode := range []uint8{2, 99, 255} { + r := &Record{Opcode: opcode, Key: []byte("k"), Value: []byte("v")} + if err := w.Append(r); !errors.Is(err, ErrInvalidOpcode) { + t.Errorf("opcode %d: expected ErrInvalidOpcode, got %v", opcode, err) + } + } +} + // TestRotation_NewSegmentCreatedAfterSizeExceeded checks that segment file rotates on limit overrun. func TestRotation_NewSegmentCreatedAfterSizeExceeded(t *testing.T) { dir := t.TempDir() From 05deb0c6c012dfb68c96eef94623b0bb3de9310c Mon Sep 17 00:00:00 2001 From: Saptak Manna Date: Wed, 10 Jun 2026 09:03:45 +0530 Subject: [PATCH 18/18] Seed currentSizeBytes from disk when opening a segment --- internal/storage/wal/writer.go | 9 +++++- internal/storage/wal/writer_test.go | 45 +++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/internal/storage/wal/writer.go b/internal/storage/wal/writer.go index 7c4ce5e..3a242d4 100644 --- a/internal/storage/wal/writer.go +++ b/internal/storage/wal/writer.go @@ -78,7 +78,14 @@ func (writer *LogWriter) rotateActiveFile() error { } writer.activeFile = file - writer.currentSizeBytes = 0 + + info, err := file.Stat() + if err != nil { + file.Close() + return fmt.Errorf("failed to stat WAL segment %s: %w", segmentPath, err) + } + writer.currentSizeBytes = info.Size() + return nil } diff --git a/internal/storage/wal/writer_test.go b/internal/storage/wal/writer_test.go index 995251e..24ef4ca 100644 --- a/internal/storage/wal/writer_test.go +++ b/internal/storage/wal/writer_test.go @@ -309,6 +309,51 @@ func TestRotation_SizeResetAfterRotation(t *testing.T) { } } +// TestRotation_ReopenedSegment_SizeAccountedFor verifies that when a writer +// reopens an existing segment (e.g. after recovery), it seeds currentSizeBytes +// from the file's actual size so the rotation threshold isn't silently bypassed. +func TestRotation_ReopenedSegment_SizeAccountedFor(t *testing.T) { + dir := t.TempDir() + + // Phase 1: Write data to segment 1, then close. + w1, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter (phase 1): %v", err) + } + for i := 0; i < 10; i++ { + r := &Record{Opcode: OpcodePut, Key: []byte(fmt.Sprintf("k%d", i)), Value: []byte("v")} + if err := w1.Append(r); err != nil { + t.Fatalf("Append (phase 1): %v", err) + } + } + if err := w1.Close(); err != nil { + t.Fatalf("Close (phase 1): %v", err) + } + + // Get actual file size on disk. + info, err := os.Stat(segmentPath(dir, 1)) + if err != nil { + t.Fatalf("Stat: %v", err) + } + fileSize := info.Size() + if fileSize == 0 { + t.Fatal("segment file is unexpectedly empty") + } + + // Phase 2: Reopen the same segment (simulating post-recovery). + w2, err := NewLogWriter(dir, 1) + if err != nil { + t.Fatalf("NewLogWriter (phase 2): %v", err) + } + defer w2.Close() + + // The writer must account for the preexisting data. + if w2.currentSizeBytes != fileSize { + t.Errorf("currentSizeBytes = %d, want %d (preexisting file size)", + w2.currentSizeBytes, fileSize) + } +} + // TestAppend_ConcurrentWrites_NoDataRace validates concurrent WAL writes do not cause data races. func TestAppend_ConcurrentWrites_NoDataRace(t *testing.T) { dir := t.TempDir()