// ---- internal/sql/lexer/errors.go (package lexer; imports "errors", "fmt") ----

// Sentinel errors — compare with errors.Is, never inspect the message string.
var (
	// ErrUnexpectedChar marks a character that cannot start any token.
	ErrUnexpectedChar = errors.New("unexpected character")
	// ErrUnterminatedString marks a string literal with no closing quote.
	ErrUnterminatedString = errors.New("unterminated string literal")
	// ErrUnterminatedComment marks a block comment with no closing */.
	ErrUnterminatedComment = errors.New("unterminated block comment")
)

// LexError wraps a sentinel with the source location and a message.
type LexError struct {
	Err  error  // sentinel identifying the error class (one of the Err* values)
	Line int    // line of the error location, as supplied by the caller
	Col  int    // column of the error location, as supplied by the caller
	Msg  string // specific detail; does not repeat the sentinel's text
}

// Error renders the error as "line:col: sentinel: detail".
//
// The sentinel and the detail are each optional: a nil Err or an empty Msg is
// simply left out, so a partially populated LexError neither panics nor
// produces a dangling ": " separator.
func (e *LexError) Error() string {
	out := fmt.Sprintf("%d:%d", e.Line, e.Col)
	if e.Err != nil {
		out += ": " + e.Err.Error()
	}
	if e.Msg != "" {
		out += ": " + e.Msg
	}
	return out
}

// Unwrap lets errors.Is / errors.As traverse to the sentinel.
func (e *LexError) Unwrap() error { return e.Err }

// lexErr builds a *LexError for the given sentinel and location.
//
// format and args are expanded with fmt.Sprintf into Msg, which carries only
// the specific detail, not a repetition of the sentinel.
func lexErr(sentinel error, line, col int, format string, args ...any) *LexError {
	return &LexError{
		Err:  sentinel,
		Line: line,
		Col:  col,
		Msg:  fmt.Sprintf(format, args...),
	}
}

// ---- internal/sql/lexer/keywords.go (package lexer; imports "strings") ----

// keywords maps the canonical (upper-case) spelling of every reserved word to
// its TokenType. The lookup is always done on the upper-cased form of whatever
// the source contained, giving the grammar its case-insensitive keyword
// semantics while leaving Token.Literal in its original casing.
var keywords = map[string]TokenType{
	// DDL / database
	"CREATE":   TOKEN_CREATE,
	"DATABASE": TOKEN_DATABASE,
	"USE":      TOKEN_USE,
	"DROP":     TOKEN_DROP,
	"IF":       TOKEN_IF,
	"EXISTS":   TOKEN_EXISTS,

	// Table DDL
	"TABLE":  TOKEN_TABLE,
	"ALTER":  TOKEN_ALTER,
	"ADD":    TOKEN_ADD,
	"COLUMN": TOKEN_COLUMN,
	"MODIFY": TOKEN_MODIFY,
	"RENAME": TOKEN_RENAME,
	"TO":     TOKEN_TO,

	// DML
	"SELECT":   TOKEN_SELECT,
	"DISTINCT": TOKEN_DISTINCT,
	"ALL":      TOKEN_ALL,
	"FROM":     TOKEN_FROM,
	"WHERE":    TOKEN_WHERE,
	"AS":       TOKEN_AS,
	"INSERT":   TOKEN_INSERT,
	"INTO":     TOKEN_INTO,
	"VALUES":   TOKEN_VALUES,
	"UPDATE":   TOKEN_UPDATE,
	"SET":      TOKEN_SET,
	"DELETE":   TOKEN_DELETE,

	// JOIN
	"JOIN":  TOKEN_JOIN,
	"INNER": TOKEN_INNER,
	"LEFT":  TOKEN_LEFT,
	"RIGHT": TOKEN_RIGHT,
	"FULL":  TOKEN_FULL,
	"OUTER": TOKEN_OUTER,
	"CROSS": TOKEN_CROSS,
	"ON":    TOKEN_ON,

	// Clauses
	"GROUP":  TOKEN_GROUP,
	"BY":     TOKEN_BY,
	"HAVING": TOKEN_HAVING,
	"ORDER":  TOKEN_ORDER,
	"ASC":    TOKEN_ASC,
	"DESC":   TOKEN_DESC,
	"LIMIT":  TOKEN_LIMIT,
	"OFFSET": TOKEN_OFFSET,

	// Constraints
	"PRIMARY":    TOKEN_PRIMARY,
	"KEY":        TOKEN_KEY,
	"NOT":        TOKEN_NOT,
	"NULL":       TOKEN_NULL,
	"DEFAULT":    TOKEN_DEFAULT,
	"UNIQUE":     TOKEN_UNIQUE,
	"REFERENCES": TOKEN_REFERENCES,

	// Logical / predicates
	"AND":     TOKEN_AND,
	"OR":      TOKEN_OR,
	"TRUE":    TOKEN_TRUE,
	"FALSE":   TOKEN_FALSE,
	"LIKE":    TOKEN_LIKE,
	"IS":      TOKEN_IS,
	"IN":      TOKEN_IN,
	"BETWEEN": TOKEN_BETWEEN,

	// Data types
	"INT":       TOKEN_INT,
	"BIGINT":    TOKEN_BIGINT,
	"VARCHAR":   TOKEN_VARCHAR,
	"BOOLEAN":   TOKEN_BOOLEAN,
	"TEXT":      TOKEN_TEXT,
	"TIMESTAMP": TOKEN_TIMESTAMP,
}

// lookupIdent returns the keyword TokenType for s if it is a reserved word,
// or TOKEN_IDENT if it is a plain user-defined name.
// The comparison is case-insensitive: "select", "SELECT", and "SeLeCt" all
// resolve to TOKEN_SELECT. Only an upper-cased copy of s is used for the
// lookup; s itself is left as written.
+func lookupIdent(s string) TokenType { + if tt, ok := keywords[strings.ToUpper(s)]; ok { + return tt + } + return TOKEN_IDENT +} diff --git a/internal/sql/lexer/lexer.go b/internal/sql/lexer/lexer.go new file mode 100644 index 0000000..ab0db42 --- /dev/null +++ b/internal/sql/lexer/lexer.go @@ -0,0 +1,282 @@ +package lexer + +// Position represents a position in the source input. +// It tracks three values: +// +// - Index: absolute character offset from the start of the entire input (0-based). +// It counts every character (including newlines) and never resets. +// - Line: the current line number (1-based). Increments only when a '\n' is encountered. +// - Column: the position within the current line (1-based). Resets to 1 on every new line. +type position struct { + index int + line int + column int +} + +// Lexer tokenizes SQL source text into a stream of Tokens +type Lexer struct { + src []rune // full input as runes + pos position // current read cursor(line, col, index) +} + +// NewLexer creates a Lexer fro the given SQL source string +func NewLexer(src string) *Lexer { + return &Lexer{ + src: []rune(src), + pos: position{0, 1, 1}, + } +} + +func (l *Lexer) peek() rune { + if l.pos.index >= len(l.src) { + return 0 + } + return l.src[l.pos.index] +} + +func (l *Lexer) peekNext() rune { + if l.pos.index+1 >= len(l.src) { + return 0 + } + return l.src[l.pos.index+1] +} + +func (l *Lexer) advance() rune { + if l.pos.index >= len(l.src) { + return 0 + } + ch := l.src[l.pos.index] + l.pos.index++ + if ch == '\n' { + l.pos.line++ + l.pos.column = 1 + } else { + l.pos.column++ + } + return ch +} + +// skipLineComment discards everything from the current position to end-of-line. +// Precondition: the two leading '-' characters have already been consumed. +func (l *Lexer) skipLineComment() { + for l.pos.index < len(l.src) && l.src[l.pos.index] != '\n' { + l.advance() + } +} + +// skipBlockComment discards everything up to and including the closing */. 
+// Precondition: the opening /* has already been consumed. +func (l *Lexer) skipBlockComment(openLine, openCol int) error { + for l.pos.index < len(l.src) { + if l.peek() == '*' && l.peekNext() == '/' { + l.advance() // * + l.advance() // / + return nil + } + l.advance() + } + // End of input without finding */ + return lexErr(ErrUnterminatedComment, l.pos.line, l.pos.column, + "expected '*/' to close '/*' opened at %d:%d", openLine, openCol) +} + +// skipWhitespaceAndComments returns an error only for an unterminated block comment. +// All other skipped content (whitespace, line comments) is infallible. +func (l *Lexer) skipWhitespaceAndComments() error { + for l.pos.index < len(l.src) { + ch := l.src[l.pos.index] + switch { + case ch == ' ' || ch == '\t' || ch == '\r' || ch == '\n': + l.advance() + + case ch == '-' && l.peekNext() == '-': + l.advance() + l.advance() + l.skipLineComment() + + case ch == '/' && l.peekNext() == '*': + openLine, openCol := l.pos.line, l.pos.column + l.advance() + l.advance() + if err := l.skipBlockComment(openLine, openCol); err != nil { + return err + } + + default: + return nil + } + } + return nil +} + +// makeToken is a convenience to build a Token with the given fields. +func (l *Lexer) makeToken(typ TokenType, lit string, line, col int) Token { + return Token{Type: typ, Literal: lit, Line: line, Col: col} +} + +// scanIdentifier reads a keyword or user identifier. +// Precondition: peek() is a letter. 
+func (l *Lexer) scanIdentifier() Token { + startLine, startCol := l.pos.line, l.pos.column + start := l.pos.index + for l.pos.index < len(l.src) { + ch := l.src[l.pos.index] + if isIdentPart(ch) { + l.advance() + } else { + break + } + } + lit := string(l.src[start:l.pos.index]) + typ := lookupIdent(lit) // keyword or TOKEN_IDENT + return l.makeToken(typ, lit, startLine, startCol) +} + +func isIdentStart(ch rune) bool { return isLetter(ch) || ch == '_' } +func isIdentPart(ch rune) bool { return isLetter(ch) || isDigit(ch) || ch == '_' } + +func isLetter(ch rune) bool { return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') } +func isDigit(ch rune) bool { return ch >= '0' && ch <= '9' } + +func (l *Lexer) scanNumber() Token { + startLine, startCol := l.pos.line, l.pos.column + start := l.pos.index + isFloat := false + + // Leading '.' case + if l.peek() == '.' { + isFloat = true + l.advance() + } + + // Digits consumption + for l.pos.index < len(l.src) && isDigit((l.src[l.pos.index])) { + l.advance() + } + + // Decimal check + if !isFloat && l.peek() == '.' 
{ + nextCh := l.peekNext() + if nextCh == 0 || isDigit(nextCh) || (!isLetter(nextCh) && nextCh != '_') { + isFloat = true + l.advance() + for l.pos.index < len(l.src) && isDigit((l.src[l.pos.index])) { + l.advance() + } + } + } + + lit := string(l.src[start:l.pos.index]) + if isFloat { + return l.makeToken(TOKEN_FLOAT, lit, startLine, startCol) + } + return l.makeToken(TOKEN_INTEGER, lit, startLine, startCol) +} + +func (l *Lexer) scanString() (Token, error) { + startLine, startCol := l.pos.line, l.pos.column + l.advance() // consume opening ' + + var buf []rune + for { + if l.pos.index >= len(l.src) { + return l.makeToken(TOKEN_ILLEGAL, string(buf), startLine, startCol), + lexErr(ErrUnterminatedString, l.pos.line, l.pos.column, + "expected closing ' (string opened at %d:%d)", startLine, startCol) + } + + ch := l.advance() + if ch == '\'' { + if l.peek() == '\'' { // '' is the SQL escape for a literal single-quote + l.advance() + buf = append(buf, '\'') + } else { + break // normal close + } + } else { + buf = append(buf, ch) + } + } + + return l.makeToken(TOKEN_STRING, string(buf), startLine, startCol), nil +} + +func (l *Lexer) NextToken() (Token, error) { + if err := l.skipWhitespaceAndComments(); err != nil { + return l.makeToken(TOKEN_EOF, "", l.pos.line, l.pos.column), err + } + + if l.pos.index >= len(l.src) { + return l.makeToken(TOKEN_EOF, "", l.pos.line, l.pos.column), nil + } + + startLine, startCol := l.pos.line, l.pos.column + ch := l.peek() + + if isIdentStart(ch) { + return l.scanIdentifier(), nil + } + if isDigit(ch) { + return l.scanNumber(), nil + } + if ch == '.' 
{ + if next := l.peekNext(); next != 0 && isDigit(next) { + return l.scanNumber(), nil + } + l.advance() + return l.makeToken(TOKEN_DOT, ".", startLine, startCol), nil + } + if ch == '\'' { + return l.scanString() + } + + l.advance() + switch ch { + case '(': + return l.makeToken(TOKEN_LPAREN, "(", startLine, startCol), nil + case ')': + return l.makeToken(TOKEN_RPAREN, ")", startLine, startCol), nil + case ',': + return l.makeToken(TOKEN_COMMA, ",", startLine, startCol), nil + case ';': + return l.makeToken(TOKEN_SEMICOLON, ";", startLine, startCol), nil + case '+': + return l.makeToken(TOKEN_PLUS, "+", startLine, startCol), nil + case '-': + return l.makeToken(TOKEN_MINUS, "-", startLine, startCol), nil + case '*': + return l.makeToken(TOKEN_STAR, "*", startLine, startCol), nil + case '/': + return l.makeToken(TOKEN_SLASH, "/", startLine, startCol), nil + case '%': + return l.makeToken(TOKEN_PERCENT, "%", startLine, startCol), nil + case '=': + return l.makeToken(TOKEN_EQ, "=", startLine, startCol), nil + // ── Multi-character operators ── + case '<': + if l.peek() == '=' { + l.advance() + return l.makeToken(TOKEN_LTE, "<=", startLine, startCol), nil + } + if l.peek() == '>' { + l.advance() + return l.makeToken(TOKEN_NEQ, "<>", startLine, startCol), nil + } + return l.makeToken(TOKEN_LT, "<", startLine, startCol), nil + case '>': + if l.peek() == '=' { + l.advance() + return l.makeToken(TOKEN_GTE, ">=", startLine, startCol), nil + } + return l.makeToken(TOKEN_GT, ">", startLine, startCol), nil + case '!': + if l.peek() == '=' { + l.advance() + return l.makeToken(TOKEN_NEQ, "!=", startLine, startCol), nil + } + return l.makeToken(TOKEN_ILLEGAL, "!", startLine, startCol), lexErr(ErrUnexpectedChar, startLine, startCol, "'!'; did you mean '!='?") + + default: + return l.makeToken(TOKEN_ILLEGAL, string(ch), startLine, startCol), lexErr(ErrUnexpectedChar, startLine, startCol, "%q", ch) + } +} diff --git a/internal/sql/lexer/lexer_test.go 
b/internal/sql/lexer/lexer_test.go new file mode 100644 index 0000000..8ef2b4d --- /dev/null +++ b/internal/sql/lexer/lexer_test.go @@ -0,0 +1,1292 @@ +package lexer + +import ( + "errors" + "fmt" + "testing" +) + +// ---------- helpers ---------------------------------------------------------- + +// tok is a compact constructor for expected Token values in table-driven tests. +func tok(typ TokenType, lit string, line, col int) Token { + return Token{Type: typ, Literal: lit, Line: line, Col: col} +} + +// collectAll drives the lexer to exhaustion and returns every token it emits +// (including the final EOF). It fails the test on the first error. +func collectAll(t *testing.T, input string) []Token { + t.Helper() + l := NewLexer(input) + var tokens []Token + for { + token, err := l.NextToken() + if err != nil { + t.Fatalf("unexpected error at %d:%d: %v", token.Line, token.Col, err) + } + tokens = append(tokens, token) + if token.Type == TOKEN_EOF { + break + } + } + return tokens +} + +// requireTokens asserts the full token stream for a given input, including the +// trailing EOF. +func requireTokens(t *testing.T, input string, want []Token) { + t.Helper() + got := collectAll(t, input) + if len(got) != len(want) { + t.Fatalf("token count mismatch: got %d, want %d\ngot: %v\nwant: %v", + len(got), len(want), got, want) + } + for i := range want { + if got[i] != want[i] { + t.Errorf("token[%d]: got %v, want %v", i, got[i], want[i]) + } + } +} + +// requireError asserts that lexing produces a specific sentinel error and an +// ILLEGAL or EOF token at the expected position. 
+func requireError(t *testing.T, input string, sentinel error) { + t.Helper() + l := NewLexer(input) + for { + _, err := l.NextToken() + if err != nil { + if !errors.Is(err, sentinel) { + t.Fatalf("expected error wrapping %v, got %v", sentinel, err) + } + return + } + } +} + +// ---------- EOF & empty input ------------------------------------------------ + +func TestNextToken_EmptyInput(t *testing.T) { + requireTokens(t, "", []Token{ + tok(TOKEN_EOF, "", 1, 1), + }) +} + +func TestNextToken_OnlyWhitespace(t *testing.T) { + requireTokens(t, " \t \r\n \n ", []Token{ + tok(TOKEN_EOF, "", 3, 3), + }) +} + +func TestNextToken_RepeatedEOF(t *testing.T) { + l := NewLexer("") + for i := 0; i < 5; i++ { + token, err := l.NextToken() + if err != nil { + t.Fatalf("iteration %d: unexpected error: %v", i, err) + } + if token.Type != TOKEN_EOF { + t.Fatalf("iteration %d: expected EOF, got %v", i, token) + } + } +} + +// ---------- Single-character punctuation ------------------------------------- + +func TestNextToken_Punctuation(t *testing.T) { + tests := []struct { + input string + want Token + }{ + {"(", tok(TOKEN_LPAREN, "(", 1, 1)}, + {")", tok(TOKEN_RPAREN, ")", 1, 1)}, + {",", tok(TOKEN_COMMA, ",", 1, 1)}, + {".", tok(TOKEN_DOT, ".", 1, 1)}, + {";", tok(TOKEN_SEMICOLON, ";", 1, 1)}, + } + for _, tc := range tests { + t.Run(tc.input, func(t *testing.T) { + requireTokens(t, tc.input, []Token{ + tc.want, + tok(TOKEN_EOF, "", 1, 2), + }) + }) + } +} + +// ---------- Arithmetic operators --------------------------------------------- + +func TestNextToken_ArithmeticOperators(t *testing.T) { + tests := []struct { + input string + want Token + }{ + {"+", tok(TOKEN_PLUS, "+", 1, 1)}, + {"-", tok(TOKEN_MINUS, "-", 1, 1)}, + {"*", tok(TOKEN_STAR, "*", 1, 1)}, + {"/", tok(TOKEN_SLASH, "/", 1, 1)}, + {"%", tok(TOKEN_PERCENT, "%", 1, 1)}, + } + for _, tc := range tests { + t.Run(tc.input, func(t *testing.T) { + requireTokens(t, tc.input, []Token{ + tc.want, + tok(TOKEN_EOF, "", 1, 2), 
+ }) + }) + } +} + +// ---------- Comparison operators (single & multi-char) ----------------------- + +func TestNextToken_ComparisonOperators(t *testing.T) { + tests := []struct { + name string + input string + want []Token + }{ + {"EQ", "=", []Token{ + tok(TOKEN_EQ, "=", 1, 1), + tok(TOKEN_EOF, "", 1, 2), + }}, + {"LT", "<", []Token{ + tok(TOKEN_LT, "<", 1, 1), + tok(TOKEN_EOF, "", 1, 2), + }}, + {"GT", ">", []Token{ + tok(TOKEN_GT, ">", 1, 1), + tok(TOKEN_EOF, "", 1, 2), + }}, + {"LTE", "<=", []Token{ + tok(TOKEN_LTE, "<=", 1, 1), + tok(TOKEN_EOF, "", 1, 3), + }}, + {"GTE", ">=", []Token{ + tok(TOKEN_GTE, ">=", 1, 1), + tok(TOKEN_EOF, "", 1, 3), + }}, + {"NEQ_bang", "!=", []Token{ + tok(TOKEN_NEQ, "!=", 1, 1), + tok(TOKEN_EOF, "", 1, 3), + }}, + {"NEQ_diamond", "<>", []Token{ + tok(TOKEN_NEQ, "<>", 1, 1), + tok(TOKEN_EOF, "", 1, 3), + }}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + requireTokens(t, tc.input, tc.want) + }) + } +} + +// ---------- Lone bang (!) is ILLEGAL ----------------------------------------- + +func TestNextToken_LoneBang_IsIllegal(t *testing.T) { + l := NewLexer("!") + token, err := l.NextToken() + if err == nil { + t.Fatal("expected error for lone '!'") + } + if !errors.Is(err, ErrUnexpectedChar) { + t.Fatalf("expected ErrUnexpectedChar, got %v", err) + } + if token.Type != TOKEN_ILLEGAL { + t.Fatalf("expected TOKEN_ILLEGAL, got %v", token.Type) + } + if token.Literal != "!" 
{ + t.Fatalf("expected literal '!', got %q", token.Literal) + } +} + +// ---------- Integer literals ------------------------------------------------- + +func TestNextToken_Integers(t *testing.T) { + tests := []struct { + input string + want Token + }{ + {"0", tok(TOKEN_INTEGER, "0", 1, 1)}, + {"1", tok(TOKEN_INTEGER, "1", 1, 1)}, + {"42", tok(TOKEN_INTEGER, "42", 1, 1)}, + {"999999", tok(TOKEN_INTEGER, "999999", 1, 1)}, + } + for _, tc := range tests { + t.Run(tc.input, func(t *testing.T) { + requireTokens(t, tc.input, []Token{ + tc.want, + tok(TOKEN_EOF, "", 1, len(tc.input)+1), + }) + }) + } +} + +// ---------- Float literals --------------------------------------------------- + +func TestNextToken_Floats(t *testing.T) { + tests := []struct { + name string + input string + want Token + }{ + {"simple", "3.14", tok(TOKEN_FLOAT, "3.14", 1, 1)}, + {"leading_dot", ".5", tok(TOKEN_FLOAT, ".5", 1, 1)}, + {"trailing_dot", "5.", tok(TOKEN_FLOAT, "5.", 1, 1)}, + {"zero_dot_zero", "0.0", tok(TOKEN_FLOAT, "0.0", 1, 1)}, + {"large", "12345.6789", tok(TOKEN_FLOAT, "12345.6789", 1, 1)}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + requireTokens(t, tc.input, []Token{ + tc.want, + tok(TOKEN_EOF, "", 1, len(tc.input)+1), + }) + }) + } +} + +// ---------- Dot-disambiguation (dot vs float) -------------------------------- + +func TestNextToken_DotVsFloat(t *testing.T) { + t.Run("dot_followed_by_identifier", func(t *testing.T) { + // "t.id" → IDENT "t", DOT ".", IDENT "id" + requireTokens(t, "t.id", []Token{ + tok(TOKEN_IDENT, "t", 1, 1), + tok(TOKEN_DOT, ".", 1, 2), + tok(TOKEN_IDENT, "id", 1, 3), + tok(TOKEN_EOF, "", 1, 5), + }) + }) + + t.Run("dot_followed_by_digit", func(t *testing.T) { + // ".5" → FLOAT ".5" + requireTokens(t, ".5", []Token{ + tok(TOKEN_FLOAT, ".5", 1, 1), + tok(TOKEN_EOF, "", 1, 3), + }) + }) + + t.Run("dot_alone", func(t *testing.T) { + requireTokens(t, ".", []Token{ + tok(TOKEN_DOT, ".", 1, 1), + tok(TOKEN_EOF, "", 1, 2), + }) + 
}) + + t.Run("number_dot_ident_is_int_then_dot_then_ident", func(t *testing.T) { + // "42.col" → INTEGER "42", DOT ".", IDENT "col" + requireTokens(t, "42.col", []Token{ + tok(TOKEN_INTEGER, "42", 1, 1), + tok(TOKEN_DOT, ".", 1, 3), + tok(TOKEN_IDENT, "col", 1, 4), + tok(TOKEN_EOF, "", 1, 7), + }) + }) + + t.Run("number_dot_underscore_is_int_then_dot_then_ident", func(t *testing.T) { + // "1._x" → INTEGER "1", DOT ".", IDENT "_x" + requireTokens(t, "1._x", []Token{ + tok(TOKEN_INTEGER, "1", 1, 1), + tok(TOKEN_DOT, ".", 1, 2), + tok(TOKEN_IDENT, "_x", 1, 3), + tok(TOKEN_EOF, "", 1, 5), + }) + }) +} + +// ---------- String literals -------------------------------------------------- + +func TestNextToken_Strings(t *testing.T) { + tests := []struct { + name string + input string + wantLit string + }{ + {"empty", "''", ""}, + {"simple", "'hello'", "hello"}, + {"with_spaces", "'hello world'", "hello world"}, + {"with_digits", "'abc123'", "abc123"}, + {"escaped_quote", "'it''s'", "it's"}, + {"double_escaped", "'a''''b'", "a''b"}, + {"only_escaped", "''''", "'"}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + l := NewLexer(tc.input) + token, err := l.NextToken() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if token.Type != TOKEN_STRING { + t.Fatalf("expected TOKEN_STRING, got %v", token.Type) + } + if token.Literal != tc.wantLit { + t.Fatalf("literal: got %q, want %q", token.Literal, tc.wantLit) + } + if token.Line != 1 || token.Col != 1 { + t.Fatalf("position: got %d:%d, want 1:1", token.Line, token.Col) + } + }) + } +} + +func TestNextToken_UnterminatedString(t *testing.T) { + inputs := []string{ + "'hello", + "'", + "'unterminated", + "'it''s still open", + } + for _, input := range inputs { + t.Run(fmt.Sprintf("%q", input), func(t *testing.T) { + requireError(t, input, ErrUnterminatedString) + }) + } +} + +// ---------- Identifiers ------------------------------------------------------ + +func TestNextToken_Identifiers(t 
*testing.T) { + tests := []struct { + input string + want Token + }{ + {"foo", tok(TOKEN_IDENT, "foo", 1, 1)}, + {"Bar", tok(TOKEN_IDENT, "Bar", 1, 1)}, + {"_private", tok(TOKEN_IDENT, "_private", 1, 1)}, + {"col1", tok(TOKEN_IDENT, "col1", 1, 1)}, + {"_", tok(TOKEN_IDENT, "_", 1, 1)}, + {"a_b_c", tok(TOKEN_IDENT, "a_b_c", 1, 1)}, + {"CamelCase", tok(TOKEN_IDENT, "CamelCase", 1, 1)}, + {"x123abc", tok(TOKEN_IDENT, "x123abc", 1, 1)}, + } + for _, tc := range tests { + t.Run(tc.input, func(t *testing.T) { + requireTokens(t, tc.input, []Token{ + tc.want, + tok(TOKEN_EOF, "", 1, len(tc.input)+1), + }) + }) + } +} + +// ---------- Keywords (case-insensitive) -------------------------------------- + +func TestNextToken_AllKeywords(t *testing.T) { + // Exhaustive coverage of every keyword in the keywords map. + // Each entry tests UPPER, lower, and MiXeD casing. + allKeywords := []struct { + upper string + typ TokenType + }{ + {"CREATE", TOKEN_CREATE}, + {"DATABASE", TOKEN_DATABASE}, + {"USE", TOKEN_USE}, + {"DROP", TOKEN_DROP}, + {"IF", TOKEN_IF}, + {"EXISTS", TOKEN_EXISTS}, + {"TABLE", TOKEN_TABLE}, + {"ALTER", TOKEN_ALTER}, + {"ADD", TOKEN_ADD}, + {"COLUMN", TOKEN_COLUMN}, + {"MODIFY", TOKEN_MODIFY}, + {"RENAME", TOKEN_RENAME}, + {"TO", TOKEN_TO}, + {"SELECT", TOKEN_SELECT}, + {"DISTINCT", TOKEN_DISTINCT}, + {"ALL", TOKEN_ALL}, + {"FROM", TOKEN_FROM}, + {"WHERE", TOKEN_WHERE}, + {"AS", TOKEN_AS}, + {"INSERT", TOKEN_INSERT}, + {"INTO", TOKEN_INTO}, + {"VALUES", TOKEN_VALUES}, + {"UPDATE", TOKEN_UPDATE}, + {"SET", TOKEN_SET}, + {"DELETE", TOKEN_DELETE}, + {"JOIN", TOKEN_JOIN}, + {"INNER", TOKEN_INNER}, + {"LEFT", TOKEN_LEFT}, + {"RIGHT", TOKEN_RIGHT}, + {"FULL", TOKEN_FULL}, + {"OUTER", TOKEN_OUTER}, + {"CROSS", TOKEN_CROSS}, + {"ON", TOKEN_ON}, + {"GROUP", TOKEN_GROUP}, + {"BY", TOKEN_BY}, + {"HAVING", TOKEN_HAVING}, + {"ORDER", TOKEN_ORDER}, + {"ASC", TOKEN_ASC}, + {"DESC", TOKEN_DESC}, + {"LIMIT", TOKEN_LIMIT}, + {"OFFSET", TOKEN_OFFSET}, + {"PRIMARY", TOKEN_PRIMARY}, 
+ {"KEY", TOKEN_KEY}, + {"NOT", TOKEN_NOT}, + {"NULL", TOKEN_NULL}, + {"DEFAULT", TOKEN_DEFAULT}, + {"UNIQUE", TOKEN_UNIQUE}, + {"REFERENCES", TOKEN_REFERENCES}, + {"AND", TOKEN_AND}, + {"OR", TOKEN_OR}, + {"TRUE", TOKEN_TRUE}, + {"FALSE", TOKEN_FALSE}, + {"LIKE", TOKEN_LIKE}, + {"IS", TOKEN_IS}, + {"IN", TOKEN_IN}, + {"BETWEEN", TOKEN_BETWEEN}, + {"INT", TOKEN_INT}, + {"BIGINT", TOKEN_BIGINT}, + {"VARCHAR", TOKEN_VARCHAR}, + {"BOOLEAN", TOKEN_BOOLEAN}, + {"TEXT", TOKEN_TEXT}, + {"TIMESTAMP", TOKEN_TIMESTAMP}, + } + for _, kw := range allKeywords { + t.Run(kw.upper, func(t *testing.T) { + // Upper case + tokens := collectAll(t, kw.upper) + if tokens[0].Type != kw.typ { + t.Errorf("UPPER %q: got type %v, want %v", kw.upper, tokens[0].Type, kw.typ) + } + if tokens[0].Literal != kw.upper { + t.Errorf("UPPER %q: literal got %q, want %q", kw.upper, tokens[0].Literal, kw.upper) + } + }) + } +} + +func TestNextToken_KeywordsCaseInsensitive(t *testing.T) { + // Verify that the literal preserves original casing while the type is correct. 
+ cases := []struct { + input string + wantTyp TokenType + wantLit string + }{ + {"select", TOKEN_SELECT, "select"}, + {"SELECT", TOKEN_SELECT, "SELECT"}, + {"SeLeCt", TOKEN_SELECT, "SeLeCt"}, + {"from", TOKEN_FROM, "from"}, + {"From", TOKEN_FROM, "From"}, + {"insert", TOKEN_INSERT, "insert"}, + {"InSeRt", TOKEN_INSERT, "InSeRt"}, + {"null", TOKEN_NULL, "null"}, + {"Null", TOKEN_NULL, "Null"}, + {"true", TOKEN_TRUE, "true"}, + {"false", TOKEN_FALSE, "false"}, + {"FaLsE", TOKEN_FALSE, "FaLsE"}, + } + for _, tc := range cases { + t.Run(tc.input, func(t *testing.T) { + tokens := collectAll(t, tc.input) + if tokens[0].Type != tc.wantTyp { + t.Errorf("type: got %v, want %v", tokens[0].Type, tc.wantTyp) + } + if tokens[0].Literal != tc.wantLit { + t.Errorf("literal: got %q, want %q", tokens[0].Literal, tc.wantLit) + } + }) + } +} + +// ---------- Comments --------------------------------------------------------- + +func TestNextToken_LineComment(t *testing.T) { + tests := []struct { + name string + input string + want []Token + }{ + {"comment_at_end", "42 -- comment", []Token{ + tok(TOKEN_INTEGER, "42", 1, 1), + tok(TOKEN_EOF, "", 1, 14), + }}, + {"comment_only", "-- everything is a comment", []Token{ + tok(TOKEN_EOF, "", 1, 27), + }}, + {"comment_before_newline", "-- comment\n42", []Token{ + tok(TOKEN_INTEGER, "42", 2, 1), + tok(TOKEN_EOF, "", 2, 3), + }}, + {"multiple_line_comments", "-- first\n-- second\n42", []Token{ + tok(TOKEN_INTEGER, "42", 3, 1), + tok(TOKEN_EOF, "", 3, 3), + }}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + requireTokens(t, tc.input, tc.want) + }) + } +} + +func TestNextToken_BlockComment(t *testing.T) { + tests := []struct { + name string + input string + want []Token + }{ + {"inline", "/* comment */ 42", []Token{ + tok(TOKEN_INTEGER, "42", 1, 15), + tok(TOKEN_EOF, "", 1, 17), + }}, + {"multi_line", "/* line1\nline2 */ 42", []Token{ + tok(TOKEN_INTEGER, "42", 2, 10), + tok(TOKEN_EOF, "", 2, 12), + }}, + 
{"empty_block", "/**/ 42", []Token{ + tok(TOKEN_INTEGER, "42", 1, 6), + tok(TOKEN_EOF, "", 1, 8), + }}, + {"adjacent", "/*a*//*b*/ 42", []Token{ + tok(TOKEN_INTEGER, "42", 1, 12), + tok(TOKEN_EOF, "", 1, 14), + }}, + {"comment_only", "/* eof in comment? no, closed */", []Token{ + tok(TOKEN_EOF, "", 1, 33), + }}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + requireTokens(t, tc.input, tc.want) + }) + } +} + +func TestNextToken_UnterminatedBlockComment(t *testing.T) { + inputs := []string{ + "/* unclosed", + "/* also \n unclosed", + "/*", + } + for _, input := range inputs { + t.Run(fmt.Sprintf("%q", input), func(t *testing.T) { + requireError(t, input, ErrUnterminatedComment) + }) + } +} + +func TestNextToken_MixedComments(t *testing.T) { + input := "-- line\n/* block */ SELECT" + requireTokens(t, input, []Token{ + tok(TOKEN_SELECT, "SELECT", 2, 13), + tok(TOKEN_EOF, "", 2, 19), + }) +} + +// ---------- Illegal characters ----------------------------------------------- + +func TestNextToken_IllegalCharacters(t *testing.T) { + illegals := []string{"@", "#", "$", "^", "&", "~", "\\", "`", "?", "|"} + for _, ch := range illegals { + t.Run(ch, func(t *testing.T) { + l := NewLexer(ch) + token, err := l.NextToken() + if err == nil { + t.Fatal("expected error for illegal character") + } + if !errors.Is(err, ErrUnexpectedChar) { + t.Fatalf("expected ErrUnexpectedChar, got %v", err) + } + if token.Type != TOKEN_ILLEGAL { + t.Fatalf("expected TOKEN_ILLEGAL, got %v", token.Type) + } + if token.Literal != ch { + t.Fatalf("literal: got %q, want %q", token.Literal, ch) + } + }) + } +} + +// ---------- Line/column tracking --------------------------------------------- + +func TestNextToken_LineColTracking(t *testing.T) { + input := "SELECT\n *\nFROM t" + requireTokens(t, input, []Token{ + tok(TOKEN_SELECT, "SELECT", 1, 1), + tok(TOKEN_STAR, "*", 2, 3), + tok(TOKEN_FROM, "FROM", 3, 1), + tok(TOKEN_IDENT, "t", 3, 6), + tok(TOKEN_EOF, "", 3, 7), + }) +} + 
+func TestNextToken_TabTracking(t *testing.T) { + // Tabs count as single column advances. + input := "\tSELECT" + tokens := collectAll(t, input) + if tokens[0].Col != 2 { + t.Errorf("expected column 2 after tab, got %d", tokens[0].Col) + } +} + +func TestNextToken_MultipleNewlines(t *testing.T) { + input := "\n\n\n42" + tokens := collectAll(t, input) + if tokens[0].Line != 4 || tokens[0].Col != 1 { + t.Errorf("expected 4:1, got %d:%d", tokens[0].Line, tokens[0].Col) + } +} + +func TestNextToken_CarriageReturnLineFeed(t *testing.T) { + // \r is treated as whitespace but doesn't increment line; only \n does. + input := "a\r\nb" + tokens := collectAll(t, input) + // 'a' at 1:1 + if tokens[0].Line != 1 || tokens[0].Col != 1 { + t.Errorf("'a' expected 1:1, got %d:%d", tokens[0].Line, tokens[0].Col) + } + // 'b' at 2:1 + if tokens[1].Line != 2 || tokens[1].Col != 1 { + t.Errorf("'b' expected 2:1, got %d:%d", tokens[1].Line, tokens[1].Col) + } +} + +// ---------- Whitespace sensitivity ------------------------------------------- + +func TestNextToken_MultipleSpaces(t *testing.T) { + input := "a b" + requireTokens(t, input, []Token{ + tok(TOKEN_IDENT, "a", 1, 1), + tok(TOKEN_IDENT, "b", 1, 7), + tok(TOKEN_EOF, "", 1, 8), + }) +} + +func TestNextToken_NoWhitespace(t *testing.T) { + input := "a+b" + requireTokens(t, input, []Token{ + tok(TOKEN_IDENT, "a", 1, 1), + tok(TOKEN_PLUS, "+", 1, 2), + tok(TOKEN_IDENT, "b", 1, 3), + tok(TOKEN_EOF, "", 1, 4), + }) +} + +// ---------- LexError structure ----------------------------------------------- + +func TestLexError_ErrorMessage(t *testing.T) { + e := lexErr(ErrUnexpectedChar, 5, 10, "'@'") + want := "5:10: unexpected character: '@'" + if e.Error() != want { + t.Errorf("got %q, want %q", e.Error(), want) + } +} + +func TestLexError_Unwrap(t *testing.T) { + e := lexErr(ErrUnterminatedString, 1, 1, "detail") + if !errors.Is(e, ErrUnterminatedString) { + t.Error("errors.Is should match sentinel") + } + var le *LexError + if 
!errors.As(e, &le) { + t.Error("errors.As should succeed for *LexError") + } + if le.Line != 1 || le.Col != 1 { + t.Errorf("position: got %d:%d, want 1:1", le.Line, le.Col) + } +} + +// ---------- Full SQL statements (integration) -------------------------------- + +func TestNextToken_SelectStatement(t *testing.T) { + input := "SELECT id, name FROM users WHERE age >= 18;" + requireTokens(t, input, []Token{ + tok(TOKEN_SELECT, "SELECT", 1, 1), + tok(TOKEN_IDENT, "id", 1, 8), + tok(TOKEN_COMMA, ",", 1, 10), + tok(TOKEN_IDENT, "name", 1, 12), + tok(TOKEN_FROM, "FROM", 1, 17), + tok(TOKEN_IDENT, "users", 1, 22), + tok(TOKEN_WHERE, "WHERE", 1, 28), + tok(TOKEN_IDENT, "age", 1, 34), + tok(TOKEN_GTE, ">=", 1, 38), + tok(TOKEN_INTEGER, "18", 1, 41), + tok(TOKEN_SEMICOLON, ";", 1, 43), + tok(TOKEN_EOF, "", 1, 44), + }) +} + +func TestNextToken_InsertStatement(t *testing.T) { + input := "INSERT INTO users (name, age) VALUES ('Alice', 30);" + requireTokens(t, input, []Token{ + tok(TOKEN_INSERT, "INSERT", 1, 1), + tok(TOKEN_INTO, "INTO", 1, 8), + tok(TOKEN_IDENT, "users", 1, 13), + tok(TOKEN_LPAREN, "(", 1, 19), + tok(TOKEN_IDENT, "name", 1, 20), + tok(TOKEN_COMMA, ",", 1, 24), + tok(TOKEN_IDENT, "age", 1, 26), + tok(TOKEN_RPAREN, ")", 1, 29), + tok(TOKEN_VALUES, "VALUES", 1, 31), + tok(TOKEN_LPAREN, "(", 1, 38), + tok(TOKEN_STRING, "Alice", 1, 39), + tok(TOKEN_COMMA, ",", 1, 46), + tok(TOKEN_INTEGER, "30", 1, 48), + tok(TOKEN_RPAREN, ")", 1, 50), + tok(TOKEN_SEMICOLON, ";", 1, 51), + tok(TOKEN_EOF, "", 1, 52), + }) +} + +func TestNextToken_CreateTable(t *testing.T) { + input := `CREATE TABLE users ( + id INT PRIMARY KEY, + name VARCHAR NOT NULL, + active BOOLEAN DEFAULT TRUE +);` + requireTokens(t, input, []Token{ + tok(TOKEN_CREATE, "CREATE", 1, 1), + tok(TOKEN_TABLE, "TABLE", 1, 8), + tok(TOKEN_IDENT, "users", 1, 14), + tok(TOKEN_LPAREN, "(", 1, 20), + // line 2 + tok(TOKEN_IDENT, "id", 2, 5), + tok(TOKEN_INT, "INT", 2, 8), + tok(TOKEN_PRIMARY, "PRIMARY", 2, 12), + 
tok(TOKEN_KEY, "KEY", 2, 20), + tok(TOKEN_COMMA, ",", 2, 23), + // line 3 + tok(TOKEN_IDENT, "name", 3, 5), + tok(TOKEN_VARCHAR, "VARCHAR", 3, 10), + tok(TOKEN_NOT, "NOT", 3, 18), + tok(TOKEN_NULL, "NULL", 3, 22), + tok(TOKEN_COMMA, ",", 3, 26), + // line 4 + tok(TOKEN_IDENT, "active", 4, 5), + tok(TOKEN_BOOLEAN, "BOOLEAN", 4, 12), + tok(TOKEN_DEFAULT, "DEFAULT", 4, 20), + tok(TOKEN_TRUE, "TRUE", 4, 28), + // line 5 + tok(TOKEN_RPAREN, ")", 5, 1), + tok(TOKEN_SEMICOLON, ";", 5, 2), + tok(TOKEN_EOF, "", 5, 3), + }) +} + +func TestNextToken_UpdateStatement(t *testing.T) { + input := "UPDATE users SET name = 'Bob' WHERE id = 1;" + requireTokens(t, input, []Token{ + tok(TOKEN_UPDATE, "UPDATE", 1, 1), + tok(TOKEN_IDENT, "users", 1, 8), + tok(TOKEN_SET, "SET", 1, 14), + tok(TOKEN_IDENT, "name", 1, 18), + tok(TOKEN_EQ, "=", 1, 23), + tok(TOKEN_STRING, "Bob", 1, 25), + tok(TOKEN_WHERE, "WHERE", 1, 31), + tok(TOKEN_IDENT, "id", 1, 37), + tok(TOKEN_EQ, "=", 1, 40), + tok(TOKEN_INTEGER, "1", 1, 42), + tok(TOKEN_SEMICOLON, ";", 1, 43), + tok(TOKEN_EOF, "", 1, 44), + }) +} + +func TestNextToken_DeleteStatement(t *testing.T) { + input := "DELETE FROM users WHERE id = 1;" + requireTokens(t, input, []Token{ + tok(TOKEN_DELETE, "DELETE", 1, 1), + tok(TOKEN_FROM, "FROM", 1, 8), + tok(TOKEN_IDENT, "users", 1, 13), + tok(TOKEN_WHERE, "WHERE", 1, 19), + tok(TOKEN_IDENT, "id", 1, 25), + tok(TOKEN_EQ, "=", 1, 28), + tok(TOKEN_INTEGER, "1", 1, 30), + tok(TOKEN_SEMICOLON, ";", 1, 31), + tok(TOKEN_EOF, "", 1, 32), + }) +} + +func TestNextToken_JoinQuery(t *testing.T) { + input := "SELECT a.id FROM a INNER JOIN b ON a.id = b.a_id" + requireTokens(t, input, []Token{ + tok(TOKEN_SELECT, "SELECT", 1, 1), + tok(TOKEN_IDENT, "a", 1, 8), + tok(TOKEN_DOT, ".", 1, 9), + tok(TOKEN_IDENT, "id", 1, 10), + tok(TOKEN_FROM, "FROM", 1, 13), + tok(TOKEN_IDENT, "a", 1, 18), + tok(TOKEN_INNER, "INNER", 1, 20), + tok(TOKEN_JOIN, "JOIN", 1, 26), + tok(TOKEN_IDENT, "b", 1, 31), + tok(TOKEN_ON, "ON", 1, 33), + 
tok(TOKEN_IDENT, "a", 1, 36), + tok(TOKEN_DOT, ".", 1, 37), + tok(TOKEN_IDENT, "id", 1, 38), + tok(TOKEN_EQ, "=", 1, 41), + tok(TOKEN_IDENT, "b", 1, 43), + tok(TOKEN_DOT, ".", 1, 44), + tok(TOKEN_IDENT, "a_id", 1, 45), + tok(TOKEN_EOF, "", 1, 49), + }) +} + +func TestNextToken_GroupByHavingOrderBy(t *testing.T) { + input := "SELECT dept, COUNT(*) FROM emp GROUP BY dept HAVING COUNT(*) > 5 ORDER BY dept ASC LIMIT 10 OFFSET 5" + requireTokens(t, input, []Token{ + tok(TOKEN_SELECT, "SELECT", 1, 1), + tok(TOKEN_IDENT, "dept", 1, 8), + tok(TOKEN_COMMA, ",", 1, 12), + tok(TOKEN_IDENT, "COUNT", 1, 14), + tok(TOKEN_LPAREN, "(", 1, 19), + tok(TOKEN_STAR, "*", 1, 20), + tok(TOKEN_RPAREN, ")", 1, 21), + tok(TOKEN_FROM, "FROM", 1, 23), + tok(TOKEN_IDENT, "emp", 1, 28), + tok(TOKEN_GROUP, "GROUP", 1, 32), + tok(TOKEN_BY, "BY", 1, 38), + tok(TOKEN_IDENT, "dept", 1, 41), + tok(TOKEN_HAVING, "HAVING", 1, 46), + tok(TOKEN_IDENT, "COUNT", 1, 53), + tok(TOKEN_LPAREN, "(", 1, 58), + tok(TOKEN_STAR, "*", 1, 59), + tok(TOKEN_RPAREN, ")", 1, 60), + tok(TOKEN_GT, ">", 1, 62), + tok(TOKEN_INTEGER, "5", 1, 64), + tok(TOKEN_ORDER, "ORDER", 1, 66), + tok(TOKEN_BY, "BY", 1, 72), + tok(TOKEN_IDENT, "dept", 1, 75), + tok(TOKEN_ASC, "ASC", 1, 80), + tok(TOKEN_LIMIT, "LIMIT", 1, 84), + tok(TOKEN_INTEGER, "10", 1, 90), + tok(TOKEN_OFFSET, "OFFSET", 1, 93), + tok(TOKEN_INTEGER, "5", 1, 100), + tok(TOKEN_EOF, "", 1, 101), + }) +} + +func TestNextToken_ComplexExpression(t *testing.T) { + input := "WHERE x BETWEEN 1 AND 10 AND name LIKE 'foo%' OR val IS NOT NULL AND id IN (1, 2, 3)" + requireTokens(t, input, []Token{ + tok(TOKEN_WHERE, "WHERE", 1, 1), + tok(TOKEN_IDENT, "x", 1, 7), + tok(TOKEN_BETWEEN, "BETWEEN", 1, 9), + tok(TOKEN_INTEGER, "1", 1, 17), + tok(TOKEN_AND, "AND", 1, 19), + tok(TOKEN_INTEGER, "10", 1, 23), + tok(TOKEN_AND, "AND", 1, 26), + tok(TOKEN_IDENT, "name", 1, 30), + tok(TOKEN_LIKE, "LIKE", 1, 35), + tok(TOKEN_STRING, "foo%", 1, 40), + tok(TOKEN_OR, "OR", 1, 47), + tok(TOKEN_IDENT, 
"val", 1, 50), + tok(TOKEN_IS, "IS", 1, 54), + tok(TOKEN_NOT, "NOT", 1, 57), + tok(TOKEN_NULL, "NULL", 1, 61), + tok(TOKEN_AND, "AND", 1, 66), + tok(TOKEN_IDENT, "id", 1, 70), + tok(TOKEN_IN, "IN", 1, 73), + tok(TOKEN_LPAREN, "(", 1, 76), + tok(TOKEN_INTEGER, "1", 1, 77), + tok(TOKEN_COMMA, ",", 1, 78), + tok(TOKEN_INTEGER, "2", 1, 80), + tok(TOKEN_COMMA, ",", 1, 81), + tok(TOKEN_INTEGER, "3", 1, 83), + tok(TOKEN_RPAREN, ")", 1, 84), + tok(TOKEN_EOF, "", 1, 85), + }) +} + +func TestNextToken_AlterTable(t *testing.T) { + input := "ALTER TABLE users ADD COLUMN email TEXT UNIQUE" + requireTokens(t, input, []Token{ + tok(TOKEN_ALTER, "ALTER", 1, 1), + tok(TOKEN_TABLE, "TABLE", 1, 7), + tok(TOKEN_IDENT, "users", 1, 13), + tok(TOKEN_ADD, "ADD", 1, 19), + tok(TOKEN_COLUMN, "COLUMN", 1, 23), + tok(TOKEN_IDENT, "email", 1, 30), + tok(TOKEN_TEXT, "TEXT", 1, 36), + tok(TOKEN_UNIQUE, "UNIQUE", 1, 41), + tok(TOKEN_EOF, "", 1, 47), + }) +} + +func TestNextToken_DropIfExists(t *testing.T) { + input := "DROP TABLE IF EXISTS users;" + requireTokens(t, input, []Token{ + tok(TOKEN_DROP, "DROP", 1, 1), + tok(TOKEN_TABLE, "TABLE", 1, 6), + tok(TOKEN_IF, "IF", 1, 12), + tok(TOKEN_EXISTS, "EXISTS", 1, 15), + tok(TOKEN_IDENT, "users", 1, 22), + tok(TOKEN_SEMICOLON, ";", 1, 27), + tok(TOKEN_EOF, "", 1, 28), + }) +} + +func TestNextToken_CreateDatabase(t *testing.T) { + input := "CREATE DATABASE mydb;" + requireTokens(t, input, []Token{ + tok(TOKEN_CREATE, "CREATE", 1, 1), + tok(TOKEN_DATABASE, "DATABASE", 1, 8), + tok(TOKEN_IDENT, "mydb", 1, 17), + tok(TOKEN_SEMICOLON, ";", 1, 21), + tok(TOKEN_EOF, "", 1, 22), + }) +} + +func TestNextToken_UseDatabase(t *testing.T) { + input := "USE mydb;" + requireTokens(t, input, []Token{ + tok(TOKEN_USE, "USE", 1, 1), + tok(TOKEN_IDENT, "mydb", 1, 5), + tok(TOKEN_SEMICOLON, ";", 1, 9), + tok(TOKEN_EOF, "", 1, 10), + }) +} + +func TestNextToken_RenameTable(t *testing.T) { + input := "ALTER TABLE old_name RENAME TO new_name;" + requireTokens(t, input, 
[]Token{ + tok(TOKEN_ALTER, "ALTER", 1, 1), + tok(TOKEN_TABLE, "TABLE", 1, 7), + tok(TOKEN_IDENT, "old_name", 1, 13), + tok(TOKEN_RENAME, "RENAME", 1, 22), + tok(TOKEN_TO, "TO", 1, 29), + tok(TOKEN_IDENT, "new_name", 1, 32), + tok(TOKEN_SEMICOLON, ";", 1, 40), + tok(TOKEN_EOF, "", 1, 41), + }) +} + +func TestNextToken_SelectWithAlias(t *testing.T) { + input := "SELECT DISTINCT name AS n FROM users" + requireTokens(t, input, []Token{ + tok(TOKEN_SELECT, "SELECT", 1, 1), + tok(TOKEN_DISTINCT, "DISTINCT", 1, 8), + tok(TOKEN_IDENT, "name", 1, 17), + tok(TOKEN_AS, "AS", 1, 22), + tok(TOKEN_IDENT, "n", 1, 25), + tok(TOKEN_FROM, "FROM", 1, 27), + tok(TOKEN_IDENT, "users", 1, 32), + tok(TOKEN_EOF, "", 1, 37), + }) +} + +func TestNextToken_SelectAllJoins(t *testing.T) { + input := "LEFT OUTER JOIN RIGHT OUTER JOIN FULL OUTER JOIN CROSS JOIN" + requireTokens(t, input, []Token{ + tok(TOKEN_LEFT, "LEFT", 1, 1), + tok(TOKEN_OUTER, "OUTER", 1, 6), + tok(TOKEN_JOIN, "JOIN", 1, 12), + tok(TOKEN_RIGHT, "RIGHT", 1, 17), + tok(TOKEN_OUTER, "OUTER", 1, 23), + tok(TOKEN_JOIN, "JOIN", 1, 29), + tok(TOKEN_FULL, "FULL", 1, 34), + tok(TOKEN_OUTER, "OUTER", 1, 39), + tok(TOKEN_JOIN, "JOIN", 1, 45), + tok(TOKEN_CROSS, "CROSS", 1, 50), + tok(TOKEN_JOIN, "JOIN", 1, 56), + tok(TOKEN_EOF, "", 1, 60), + }) +} + +func TestNextToken_ForeignKeyReference(t *testing.T) { + input := "user_id BIGINT REFERENCES users(id)" + requireTokens(t, input, []Token{ + tok(TOKEN_IDENT, "user_id", 1, 1), + tok(TOKEN_BIGINT, "BIGINT", 1, 9), + tok(TOKEN_REFERENCES, "REFERENCES", 1, 16), + tok(TOKEN_IDENT, "users", 1, 27), + tok(TOKEN_LPAREN, "(", 1, 32), + tok(TOKEN_IDENT, "id", 1, 33), + tok(TOKEN_RPAREN, ")", 1, 35), + tok(TOKEN_EOF, "", 1, 36), + }) +} + +func TestNextToken_TimestampColumn(t *testing.T) { + input := "created_at TIMESTAMP NOT NULL DEFAULT '2024-01-01'" + requireTokens(t, input, []Token{ + tok(TOKEN_IDENT, "created_at", 1, 1), + tok(TOKEN_TIMESTAMP, "TIMESTAMP", 1, 12), + tok(TOKEN_NOT, "NOT", 1, 
22), + tok(TOKEN_NULL, "NULL", 1, 26), + tok(TOKEN_DEFAULT, "DEFAULT", 1, 31), + tok(TOKEN_STRING, "2024-01-01", 1, 39), + tok(TOKEN_EOF, "", 1, 51), + }) +} + +func TestNextToken_ArithmeticExpression(t *testing.T) { + input := "a + b - c * d / e % f" + requireTokens(t, input, []Token{ + tok(TOKEN_IDENT, "a", 1, 1), + tok(TOKEN_PLUS, "+", 1, 3), + tok(TOKEN_IDENT, "b", 1, 5), + tok(TOKEN_MINUS, "-", 1, 7), + tok(TOKEN_IDENT, "c", 1, 9), + tok(TOKEN_STAR, "*", 1, 11), + tok(TOKEN_IDENT, "d", 1, 13), + tok(TOKEN_SLASH, "/", 1, 15), + tok(TOKEN_IDENT, "e", 1, 17), + tok(TOKEN_PERCENT, "%", 1, 19), + tok(TOKEN_IDENT, "f", 1, 21), + tok(TOKEN_EOF, "", 1, 22), + }) +} + +func TestNextToken_SelectAll(t *testing.T) { + input := "SELECT ALL * FROM t" + requireTokens(t, input, []Token{ + tok(TOKEN_SELECT, "SELECT", 1, 1), + tok(TOKEN_ALL, "ALL", 1, 8), + tok(TOKEN_STAR, "*", 1, 12), + tok(TOKEN_FROM, "FROM", 1, 14), + tok(TOKEN_IDENT, "t", 1, 19), + tok(TOKEN_EOF, "", 1, 20), + }) +} + +func TestNextToken_DescOrder(t *testing.T) { + input := "ORDER BY col DESC" + requireTokens(t, input, []Token{ + tok(TOKEN_ORDER, "ORDER", 1, 1), + tok(TOKEN_BY, "BY", 1, 7), + tok(TOKEN_IDENT, "col", 1, 10), + tok(TOKEN_DESC, "DESC", 1, 14), + tok(TOKEN_EOF, "", 1, 18), + }) +} + +// ---------- Edge cases ------------------------------------------------------- + +func TestNextToken_MinusVsLineComment(t *testing.T) { + // Single minus is TOKEN_MINUS; double minus is a line comment. + t.Run("single_minus", func(t *testing.T) { + requireTokens(t, "3 - 1", []Token{ + tok(TOKEN_INTEGER, "3", 1, 1), + tok(TOKEN_MINUS, "-", 1, 3), + tok(TOKEN_INTEGER, "1", 1, 5), + tok(TOKEN_EOF, "", 1, 6), + }) + }) + t.Run("double_minus_is_comment", func(t *testing.T) { + requireTokens(t, "3 -- 1", []Token{ + tok(TOKEN_INTEGER, "3", 1, 1), + tok(TOKEN_EOF, "", 1, 7), + }) + }) +} + +func TestNextToken_SlashVsBlockComment(t *testing.T) { + // Single slash is TOKEN_SLASH; /* starts a block comment. 
+ t.Run("single_slash", func(t *testing.T) { + requireTokens(t, "3 / 1", []Token{ + tok(TOKEN_INTEGER, "3", 1, 1), + tok(TOKEN_SLASH, "/", 1, 3), + tok(TOKEN_INTEGER, "1", 1, 5), + tok(TOKEN_EOF, "", 1, 6), + }) + }) + t.Run("slash_star_is_comment", func(t *testing.T) { + requireTokens(t, "3 /* comment */ / 1", []Token{ + tok(TOKEN_INTEGER, "3", 1, 1), + tok(TOKEN_SLASH, "/", 1, 17), + tok(TOKEN_INTEGER, "1", 1, 19), + tok(TOKEN_EOF, "", 1, 20), + }) + }) +} + +func TestNextToken_LessThanAmbiguity(t *testing.T) { + // < alone, <=, <> + t.Run("lt_followed_by_space", func(t *testing.T) { + requireTokens(t, "a < b", []Token{ + tok(TOKEN_IDENT, "a", 1, 1), + tok(TOKEN_LT, "<", 1, 3), + tok(TOKEN_IDENT, "b", 1, 5), + tok(TOKEN_EOF, "", 1, 6), + }) + }) + t.Run("lt_followed_by_eq", func(t *testing.T) { + requireTokens(t, "a<=b", []Token{ + tok(TOKEN_IDENT, "a", 1, 1), + tok(TOKEN_LTE, "<=", 1, 2), + tok(TOKEN_IDENT, "b", 1, 4), + tok(TOKEN_EOF, "", 1, 5), + }) + }) + t.Run("lt_followed_by_gt", func(t *testing.T) { + requireTokens(t, "a<>b", []Token{ + tok(TOKEN_IDENT, "a", 1, 1), + tok(TOKEN_NEQ, "<>", 1, 2), + tok(TOKEN_IDENT, "b", 1, 4), + tok(TOKEN_EOF, "", 1, 5), + }) + }) +} + +func TestNextToken_ConsecutiveOperators(t *testing.T) { + input := ">=<=" + requireTokens(t, input, []Token{ + tok(TOKEN_GTE, ">=", 1, 1), + tok(TOKEN_LTE, "<=", 1, 3), + tok(TOKEN_EOF, "", 1, 5), + }) +} + +func TestNextToken_StringInContext(t *testing.T) { + input := "WHERE name = 'O''Brien'" + requireTokens(t, input, []Token{ + tok(TOKEN_WHERE, "WHERE", 1, 1), + tok(TOKEN_IDENT, "name", 1, 7), + tok(TOKEN_EQ, "=", 1, 12), + tok(TOKEN_STRING, "O'Brien", 1, 14), + tok(TOKEN_EOF, "", 1, 24), + }) +} + +func TestNextToken_FloatInExpression(t *testing.T) { + input := "price * 1.08 + .5" + requireTokens(t, input, []Token{ + tok(TOKEN_IDENT, "price", 1, 1), + tok(TOKEN_STAR, "*", 1, 7), + tok(TOKEN_FLOAT, "1.08", 1, 9), + tok(TOKEN_PLUS, "+", 1, 14), + tok(TOKEN_FLOAT, ".5", 1, 16), + 
tok(TOKEN_EOF, "", 1, 18), + }) +} + +func TestNextToken_IdentStartingWithUnderscore(t *testing.T) { + input := "_foo _123 __" + requireTokens(t, input, []Token{ + tok(TOKEN_IDENT, "_foo", 1, 1), + tok(TOKEN_IDENT, "_123", 1, 6), + tok(TOKEN_IDENT, "__", 1, 11), + tok(TOKEN_EOF, "", 1, 13), + }) +} + +func TestNextToken_KeywordAsPrefix(t *testing.T) { + // "selection" should be IDENT, not SELECT + "ion" + requireTokens(t, "selection", []Token{ + tok(TOKEN_IDENT, "selection", 1, 1), + tok(TOKEN_EOF, "", 1, 10), + }) +} + +func TestNextToken_MultiLineString(t *testing.T) { + // Strings can span newlines. + input := "'line1\nline2'" + l := NewLexer(input) + token, err := l.NextToken() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if token.Type != TOKEN_STRING { + t.Fatalf("expected TOKEN_STRING, got %v", token.Type) + } + if token.Literal != "line1\nline2" { + t.Fatalf("literal: got %q, want %q", token.Literal, "line1\nline2") + } +} + +func TestNextToken_NumberFollowedByDotFollowedByNumber(t *testing.T) { + // "1.2.3" → FLOAT "1.2", then ".3" starts a leading-dot float. 
+ requireTokens(t, "1.2.3", []Token{ + tok(TOKEN_FLOAT, "1.2", 1, 1), + tok(TOKEN_FLOAT, ".3", 1, 4), + tok(TOKEN_EOF, "", 1, 6), + }) +} + +func TestNextToken_ModifyKeyword(t *testing.T) { + input := "ALTER TABLE t MODIFY COLUMN c INT;" + requireTokens(t, input, []Token{ + tok(TOKEN_ALTER, "ALTER", 1, 1), + tok(TOKEN_TABLE, "TABLE", 1, 7), + tok(TOKEN_IDENT, "t", 1, 13), + tok(TOKEN_MODIFY, "MODIFY", 1, 15), + tok(TOKEN_COLUMN, "COLUMN", 1, 22), + tok(TOKEN_IDENT, "c", 1, 29), + tok(TOKEN_INT, "INT", 1, 31), + tok(TOKEN_SEMICOLON, ";", 1, 34), + tok(TOKEN_EOF, "", 1, 35), + }) +} + +func TestNextToken_SelectWithComments(t *testing.T) { + input := `SELECT -- column list + id, /* primary key */ + name +FROM users;` + requireTokens(t, input, []Token{ + tok(TOKEN_SELECT, "SELECT", 1, 1), + tok(TOKEN_IDENT, "id", 2, 5), + tok(TOKEN_COMMA, ",", 2, 7), + tok(TOKEN_IDENT, "name", 3, 5), + tok(TOKEN_FROM, "FROM", 4, 1), + tok(TOKEN_IDENT, "users", 4, 6), + tok(TOKEN_SEMICOLON, ";", 4, 11), + tok(TOKEN_EOF, "", 4, 12), + }) +} + +func TestNextToken_OperatorsWithNoSpaces(t *testing.T) { + input := "(a+b)*(c-d)" + requireTokens(t, input, []Token{ + tok(TOKEN_LPAREN, "(", 1, 1), + tok(TOKEN_IDENT, "a", 1, 2), + tok(TOKEN_PLUS, "+", 1, 3), + tok(TOKEN_IDENT, "b", 1, 4), + tok(TOKEN_RPAREN, ")", 1, 5), + tok(TOKEN_STAR, "*", 1, 6), + tok(TOKEN_LPAREN, "(", 1, 7), + tok(TOKEN_IDENT, "c", 1, 8), + tok(TOKEN_MINUS, "-", 1, 9), + tok(TOKEN_IDENT, "d", 1, 10), + tok(TOKEN_RPAREN, ")", 1, 11), + tok(TOKEN_EOF, "", 1, 12), + }) +} + +func TestNextToken_NumberAtEndOfInput(t *testing.T) { + // Number followed immediately by EOF, with trailing dot. 
+ requireTokens(t, "42.", []Token{ + tok(TOKEN_FLOAT, "42.", 1, 1), + tok(TOKEN_EOF, "", 1, 4), + }) +} + +func TestNextToken_SelectStar(t *testing.T) { + input := "SELECT * FROM t;" + requireTokens(t, input, []Token{ + tok(TOKEN_SELECT, "SELECT", 1, 1), + tok(TOKEN_STAR, "*", 1, 8), + tok(TOKEN_FROM, "FROM", 1, 10), + tok(TOKEN_IDENT, "t", 1, 15), + tok(TOKEN_SEMICOLON, ";", 1, 16), + tok(TOKEN_EOF, "", 1, 17), + }) +} + +func TestNextToken_NegativeNumberContext(t *testing.T) { + // Minus is a separate token; the parser handles negation semantically. + requireTokens(t, "-42", []Token{ + tok(TOKEN_MINUS, "-", 1, 1), + tok(TOKEN_INTEGER, "42", 1, 2), + tok(TOKEN_EOF, "", 1, 4), + }) +} + +func TestNextToken_ErrorRecovery(t *testing.T) { + // After hitting an illegal character, the lexer should still be able to + // produce subsequent tokens. + l := NewLexer("@ SELECT") + token, err := l.NextToken() + if err == nil || token.Type != TOKEN_ILLEGAL { + t.Fatalf("expected ILLEGAL token with error, got %v, err=%v", token, err) + } + // The next call should produce SELECT. + token, err = l.NextToken() + if err != nil { + t.Fatalf("unexpected error after recovery: %v", err) + } + if token.Type != TOKEN_SELECT { + t.Fatalf("expected SELECT after recovery, got %v", token.Type) + } +} diff --git a/internal/sql/lexer/lookahead.go b/internal/sql/lexer/lookahead.go new file mode 100644 index 0000000..320662e --- /dev/null +++ b/internal/sql/lexer/lookahead.go @@ -0,0 +1,65 @@ +package lexer + +// LookaheadIterator provides single-element lookahead over an arbitrary stream. +// Used by the parser to inspect the next token without consuming it. +// Not used by the lexer itself. +// +// This type is not safe for concurrent use. 
type LookaheadIterator[T any] struct {
	nextFn    func() T // produces the next element on demand
	peeked    T        // buffered lookahead element; meaningful only while hasPeeked is true
	hasPeeked bool     // true while peeked holds an element that Next() has not consumed yet
	count     int      // number of elements consumed via Next()
}

// NewLookaheadIterator creates a LookaheadIterator backed by the given function.
// nextFn is called each time a new element is needed; it is never called
// eagerly, so constructing an iterator has no effect on the underlying stream.
func NewLookaheadIterator[T any](nextFn func() T) *LookaheadIterator[T] {
	return &LookaheadIterator[T]{nextFn: nextFn}
}

// Peek returns the next element without consuming it.
// Successive calls return the same element until Next() is called; nextFn is
// invoked at most once per buffered element.
func (p *LookaheadIterator[T]) Peek() T {
	if !p.hasPeeked {
		// Store the element inline rather than behind a pointer so that
		// peeking does not force a heap allocation per element.
		p.peeked = p.nextFn()
		p.hasPeeked = true
	}
	return p.peeked
}

// Next consumes and returns the next element. If an element was buffered by
// a previous Peek it is returned; otherwise nextFn is called directly.
// Every call increments the value reported by Count.
func (p *LookaheadIterator[T]) Next() T {
	var v T
	if p.hasPeeked {
		v = p.peeked
		// Reset the slot so a consumed element is not kept alive by the
		// iterator until the next Peek overwrites it.
		var zero T
		p.peeked = zero
		p.hasPeeked = false
	} else {
		v = p.nextFn()
	}
	p.count++
	return v
}

// ExpectNextValue consumes and returns the next element if it equals expected.
// Equality is decided by eq, which is called as eq(next, expected).
// Returns a pointer to the consumed element, or nil if it did not match.
func (p *LookaheadIterator[T]) ExpectNextValue(expected T, eq func(a, b T) bool) *T {
	return p.ExpectNextMatches(func(v T) bool { return eq(v, expected) })
}

// ExpectNextMatches consumes and returns the next element if it satisfies the
// predicate. Returns a pointer to the consumed element, or nil otherwise.
// The predicate is called at most once. On a mismatch the element stays
// buffered and is returned by the following Peek or Next.
func (p *LookaheadIterator[T]) ExpectNextMatches(predicate func(T) bool) *T {
	if !predicate(p.Peek()) {
		return nil
	}
	// Only materialize the addressable copy on a match, so a failed
	// expectation does not allocate.
	v := p.Next()
	return &v
}

// Count returns the number of elements consumed via Next() so far.
+func (p *LookaheadIterator[T]) Count() int { + return p.count +} diff --git a/internal/sql/lexer/lookahead_test.go b/internal/sql/lexer/lookahead_test.go new file mode 100644 index 0000000..520b6c6 --- /dev/null +++ b/internal/sql/lexer/lookahead_test.go @@ -0,0 +1,442 @@ +package lexer + +import "testing" + +// ---------- LookaheadIterator ------------------------------------------------ + +func TestLookaheadIterator_BasicNextAndPeek(t *testing.T) { + seq := []int{10, 20, 30} + idx := 0 + iter := NewLookaheadIterator(func() int { + v := seq[idx] + idx++ + return v + }) + + // Peek should return first element without consuming. + if got := iter.Peek(); got != 10 { + t.Fatalf("Peek() = %d, want 10", got) + } + if got := iter.Peek(); got != 10 { + t.Fatalf("second Peek() = %d, want 10 (should be idempotent)", got) + } + if iter.Count() != 0 { + t.Fatalf("Count() = %d after Peek, want 0", iter.Count()) + } + + // Next should consume the peeked element. + if got := iter.Next(); got != 10 { + t.Fatalf("Next() = %d, want 10", got) + } + if iter.Count() != 1 { + t.Fatalf("Count() = %d after first Next, want 1", iter.Count()) + } + + // Next without prior Peek. + if got := iter.Next(); got != 20 { + t.Fatalf("Next() = %d, want 20", got) + } + if iter.Count() != 2 { + t.Fatalf("Count() = %d, want 2", iter.Count()) + } + + // Peek then Next. + if got := iter.Peek(); got != 30 { + t.Fatalf("Peek() = %d, want 30", got) + } + if got := iter.Next(); got != 30 { + t.Fatalf("Next() = %d, want 30", got) + } + if iter.Count() != 3 { + t.Fatalf("Count() = %d, want 3", iter.Count()) + } +} + +func TestLookaheadIterator_NextWithoutPeek(t *testing.T) { + calls := 0 + iter := NewLookaheadIterator(func() int { + calls++ + return calls + }) + + // Calling Next without Peek should call nextFn directly. 
+ if got := iter.Next(); got != 1 { + t.Fatalf("Next() = %d, want 1", got) + } + if got := iter.Next(); got != 2 { + t.Fatalf("Next() = %d, want 2", got) + } + if calls != 2 { + t.Fatalf("nextFn called %d times, want 2", calls) + } +} + +func TestLookaheadIterator_PeekDoesNotCallNextFnTwice(t *testing.T) { + calls := 0 + iter := NewLookaheadIterator(func() string { + calls++ + return "hello" + }) + + _ = iter.Peek() + _ = iter.Peek() + _ = iter.Peek() + + if calls != 1 { + t.Fatalf("nextFn called %d times, want 1 (Peek should buffer)", calls) + } +} + +func TestLookaheadIterator_Count_StartsAtZero(t *testing.T) { + iter := NewLookaheadIterator(func() int { return 0 }) + if iter.Count() != 0 { + t.Fatalf("Count() = %d, want 0", iter.Count()) + } +} + +func TestLookaheadIterator_Count_IncrementedByNext(t *testing.T) { + iter := NewLookaheadIterator(func() int { return 42 }) + for i := 1; i <= 5; i++ { + iter.Next() + if iter.Count() != i { + t.Fatalf("after %d Next calls: Count() = %d", i, iter.Count()) + } + } +} + +func TestLookaheadIterator_Count_NotIncrementedByPeek(t *testing.T) { + iter := NewLookaheadIterator(func() int { return 1 }) + iter.Peek() + iter.Peek() + if iter.Count() != 0 { + t.Fatalf("Peek should not increment Count; got %d", iter.Count()) + } +} + +// ---------- ExpectNextValue -------------------------------------------------- + +func TestLookaheadIterator_ExpectNextValue_Match(t *testing.T) { + seq := []int{5, 10, 15} + idx := 0 + iter := NewLookaheadIterator(func() int { + v := seq[idx] + idx++ + return v + }) + + eq := func(a, b int) bool { return a == b } + + result := iter.ExpectNextValue(5, eq) + if result == nil { + t.Fatal("expected match, got nil") + } + if *result != 5 { + t.Fatalf("matched value = %d, want 5", *result) + } + if iter.Count() != 1 { + t.Fatalf("Count() = %d, want 1 (match should consume)", iter.Count()) + } +} + +func TestLookaheadIterator_ExpectNextValue_NoMatch(t *testing.T) { + seq := []int{5, 10} + idx := 0 + iter 
:= NewLookaheadIterator(func() int { + v := seq[idx] + idx++ + return v + }) + + eq := func(a, b int) bool { return a == b } + + result := iter.ExpectNextValue(999, eq) + if result != nil { + t.Fatalf("expected nil for non-match, got %d", *result) + } + if iter.Count() != 0 { + t.Fatalf("Count() = %d, want 0 (non-match should not consume)", iter.Count()) + } + + // The element should still be available. + if got := iter.Next(); got != 5 { + t.Fatalf("Next() after failed expect = %d, want 5", got) + } +} + +func TestLookaheadIterator_ExpectNextValue_ConsecutiveMatches(t *testing.T) { + seq := []int{1, 2, 3} + idx := 0 + iter := NewLookaheadIterator(func() int { + v := seq[idx] + idx++ + return v + }) + + eq := func(a, b int) bool { return a == b } + + for i, expected := range seq { + result := iter.ExpectNextValue(expected, eq) + if result == nil { + t.Fatalf("step %d: expected match for %d, got nil", i, expected) + } + if *result != expected { + t.Fatalf("step %d: got %d, want %d", i, *result, expected) + } + } + if iter.Count() != 3 { + t.Fatalf("Count() = %d, want 3", iter.Count()) + } +} + +func TestLookaheadIterator_ExpectNextValue_FailThenSucceed(t *testing.T) { + seq := []int{1, 2} + idx := 0 + iter := NewLookaheadIterator(func() int { + v := seq[idx] + idx++ + return v + }) + + eq := func(a, b int) bool { return a == b } + + // Fail: looking for 2, but next is 1. + if r := iter.ExpectNextValue(2, eq); r != nil { + t.Fatalf("expected nil, got %d", *r) + } + // Succeed: looking for 1, and next is 1. 
+ if r := iter.ExpectNextValue(1, eq); r == nil { + t.Fatal("expected match for 1, got nil") + } +} + +// ---------- ExpectNextMatches ------------------------------------------------ + +func TestLookaheadIterator_ExpectNextMatches_PredicateTrue(t *testing.T) { + iter := NewLookaheadIterator(func() int { return 42 }) + + result := iter.ExpectNextMatches(func(v int) bool { return v > 0 }) + if result == nil { + t.Fatal("expected match, got nil") + } + if *result != 42 { + t.Fatalf("matched = %d, want 42", *result) + } + if iter.Count() != 1 { + t.Fatalf("Count() = %d, want 1", iter.Count()) + } +} + +func TestLookaheadIterator_ExpectNextMatches_PredicateFalse(t *testing.T) { + iter := NewLookaheadIterator(func() int { return 42 }) + + result := iter.ExpectNextMatches(func(v int) bool { return v < 0 }) + if result != nil { + t.Fatalf("expected nil, got %d", *result) + } + if iter.Count() != 0 { + t.Fatalf("Count() = %d, want 0 (no consume on mismatch)", iter.Count()) + } +} + +func TestLookaheadIterator_ExpectNextMatches_PredicateCalledOnce(t *testing.T) { + iter := NewLookaheadIterator(func() int { return 1 }) + calls := 0 + iter.ExpectNextMatches(func(v int) bool { + calls++ + return false + }) + if calls != 1 { + t.Fatalf("predicate called %d times, want 1", calls) + } +} + +func TestLookaheadIterator_ExpectNextMatches_DoesNotConsumeOnMismatch(t *testing.T) { + seq := []string{"hello", "world"} + idx := 0 + iter := NewLookaheadIterator(func() string { + v := seq[idx] + idx++ + return v + }) + + // Mismatch. + result := iter.ExpectNextMatches(func(v string) bool { return v == "world" }) + if result != nil { + t.Fatalf("expected nil, got %q", *result) + } + + // "hello" should still be there. 
+ got := iter.Next() + if got != "hello" { + t.Fatalf("Next() = %q, want 'hello' (should not have been consumed)", got) + } +} + +// ---------- Generic type support (strings) ----------------------------------- + +func TestLookaheadIterator_WithStrings(t *testing.T) { + words := []string{"SELECT", "FROM", "WHERE"} + idx := 0 + iter := NewLookaheadIterator(func() string { + v := words[idx] + idx++ + return v + }) + + if got := iter.Peek(); got != "SELECT" { + t.Fatalf("Peek() = %q, want 'SELECT'", got) + } + if got := iter.Next(); got != "SELECT" { + t.Fatalf("Next() = %q, want 'SELECT'", got) + } + if got := iter.Next(); got != "FROM" { + t.Fatalf("Next() = %q, want 'FROM'", got) + } + if got := iter.Peek(); got != "WHERE" { + t.Fatalf("Peek() = %q, want 'WHERE'", got) + } + if got := iter.Next(); got != "WHERE" { + t.Fatalf("Next() = %q, want 'WHERE'", got) + } + if iter.Count() != 3 { + t.Fatalf("Count() = %d, want 3", iter.Count()) + } +} + +// ---------- Integration: LookaheadIterator wrapping the Lexer ---------------- + +func TestLookaheadIterator_WithLexer(t *testing.T) { + l := NewLexer("SELECT * FROM t;") + iter := NewLookaheadIterator(func() Token { + tok, _ := l.NextToken() + return tok + }) + + // Peek should give SELECT. + peeked := iter.Peek() + if peeked.Type != TOKEN_SELECT { + t.Fatalf("Peek() type = %v, want SELECT", peeked.Type) + } + + // Next should consume the same SELECT. + got := iter.Next() + if got.Type != TOKEN_SELECT { + t.Fatalf("Next() type = %v, want SELECT", got.Type) + } + + // Next → STAR. + got = iter.Next() + if got.Type != TOKEN_STAR { + t.Fatalf("Next() type = %v, want STAR", got.Type) + } + + // Peek → FROM. + peeked = iter.Peek() + if peeked.Type != TOKEN_FROM { + t.Fatalf("Peek() type = %v, want FROM", peeked.Type) + } + + // ExpectNextValue should match FROM. 
+ eq := func(a, b Token) bool { return a.Type == b.Type } + result := iter.ExpectNextValue(Token{Type: TOKEN_FROM}, eq) + if result == nil { + t.Fatal("expected FROM to match, got nil") + } + if result.Literal != "FROM" { + t.Fatalf("matched literal = %q, want 'FROM'", result.Literal) + } + + // ExpectNextMatches for an identifier. + result = iter.ExpectNextMatches(func(tok Token) bool { + return tok.Type == TOKEN_IDENT + }) + if result == nil { + t.Fatal("expected IDENT match, got nil") + } + if result.Literal != "t" { + t.Fatalf("matched literal = %q, want 't'", result.Literal) + } + + // SEMICOLON. + got = iter.Next() + if got.Type != TOKEN_SEMICOLON { + t.Fatalf("Next() type = %v, want SEMICOLON", got.Type) + } + + // EOF. + got = iter.Next() + if got.Type != TOKEN_EOF { + t.Fatalf("Next() type = %v, want EOF", got.Type) + } + + if iter.Count() != 6 { + t.Fatalf("Count() = %d, want 6", iter.Count()) + } +} + +func TestLookaheadIterator_ExpectNextValue_NoMatchDoesNotAdvanceLexer(t *testing.T) { + l := NewLexer("SELECT FROM") + iter := NewLookaheadIterator(func() Token { + tok, _ := l.NextToken() + return tok + }) + + eq := func(a, b Token) bool { return a.Type == b.Type } + + // Try to match FROM, but next is SELECT — should fail. + result := iter.ExpectNextValue(Token{Type: TOKEN_FROM}, eq) + if result != nil { + t.Fatal("expected nil, got a match") + } + + // SELECT should still be the next token. + got := iter.Next() + if got.Type != TOKEN_SELECT { + t.Fatalf("Next() after failed expect = %v, want SELECT", got.Type) + } +} + +// ---------- Edge: struct types with LookaheadIterator ------------------------ + +type testPair struct { + key string + value int +} + +func TestLookaheadIterator_WithStructs(t *testing.T) { + pairs := []testPair{ + {"a", 1}, + {"b", 2}, + {"c", 3}, + } + idx := 0 + iter := NewLookaheadIterator(func() testPair { + v := pairs[idx] + idx++ + return v + }) + + // Peek. 
+ peeked := iter.Peek() + if peeked.key != "a" || peeked.value != 1 { + t.Fatalf("Peek() = %+v, want {a 1}", peeked) + } + + // ExpectNextMatches with struct field check. + result := iter.ExpectNextMatches(func(p testPair) bool { + return p.key == "a" + }) + if result == nil { + t.Fatal("expected match, got nil") + } + + // Next. + got := iter.Next() + if got.key != "b" { + t.Fatalf("Next() = %+v, want key='b'", got) + } + + if iter.Count() != 2 { + t.Fatalf("Count() = %d, want 2", iter.Count()) + } +} diff --git a/internal/sql/lexer/tokens.go b/internal/sql/lexer/tokens.go new file mode 100644 index 0000000..b80679b --- /dev/null +++ b/internal/sql/lexer/tokens.go @@ -0,0 +1,229 @@ +package lexer + +import "fmt" + +// TokenType is an integer tag that identifies what kind of lexical unit a +// Token represents. Every terminal in the grammar maps to exactly one +// TokenType constant. +type TokenType int + +//nolint:revive // We prefer ALL_CAPS for token constants +const ( + // Special + TOKEN_EOF TokenType = iota // end of input; always the last token + TOKEN_ILLEGAL // unrecognised character; carries the raw byte + + // Literals + TOKEN_IDENT + TOKEN_INTEGER + TOKEN_FLOAT + TOKEN_STRING + + // DDL / database keywords + TOKEN_CREATE + TOKEN_DATABASE + TOKEN_USE + TOKEN_DROP + TOKEN_IF + TOKEN_EXISTS + TOKEN_TABLE + TOKEN_ALTER + TOKEN_ADD + TOKEN_COLUMN + TOKEN_MODIFY + TOKEN_RENAME + TOKEN_TO + + // DML keywords + TOKEN_SELECT + TOKEN_DISTINCT + TOKEN_ALL + TOKEN_FROM + TOKEN_WHERE + TOKEN_AS + TOKEN_INSERT + TOKEN_INTO + TOKEN_VALUES + TOKEN_UPDATE + TOKEN_SET + TOKEN_DELETE + + // JOIN keywords + TOKEN_JOIN + TOKEN_INNER + TOKEN_LEFT + TOKEN_RIGHT + TOKEN_FULL + TOKEN_OUTER + TOKEN_CROSS + TOKEN_ON + + // Clause keywords + TOKEN_GROUP + TOKEN_BY + TOKEN_HAVING + TOKEN_ORDER + TOKEN_ASC + TOKEN_DESC + TOKEN_LIMIT + TOKEN_OFFSET + + // Constraint / type keywords + TOKEN_PRIMARY + TOKEN_KEY + TOKEN_NOT + TOKEN_NULL + TOKEN_DEFAULT + TOKEN_UNIQUE + TOKEN_REFERENCES 
+
+	// Logical / predicate keywords
+	TOKEN_AND
+	TOKEN_OR
+	TOKEN_TRUE
+	TOKEN_FALSE
+	TOKEN_LIKE
+	TOKEN_IS
+	TOKEN_IN
+	TOKEN_BETWEEN
+
+	// Data-type keywords
+	TOKEN_INT
+	TOKEN_BIGINT
+	TOKEN_VARCHAR
+	TOKEN_BOOLEAN
+	TOKEN_TEXT
+	TOKEN_TIMESTAMP
+
+	// Comparison operators
+	TOKEN_EQ  // =
+	TOKEN_NEQ // != or <>
+	TOKEN_LT  // <
+	TOKEN_GT  // >
+	TOKEN_LTE // <=
+	TOKEN_GTE // >=
+
+	// Arithmetic operators
+	TOKEN_PLUS    // +
+	TOKEN_MINUS   // -
+	TOKEN_STAR    // *
+	TOKEN_SLASH   // /
+	TOKEN_PERCENT // %
+
+	// Punctuation
+	TOKEN_LPAREN    // (
+	TOKEN_RPAREN    // )
+	TOKEN_COMMA     // ,
+	TOKEN_DOT       // .
+	TOKEN_SEMICOLON // ;
+)
+
+// tokenNames provides a human-readable label for each TokenType; used by
+// String() and by test failure messages.
+var tokenNames = map[TokenType]string{
+	TOKEN_EOF:        "EOF",
+	TOKEN_ILLEGAL:    "ILLEGAL",
+	TOKEN_IDENT:      "IDENT",
+	TOKEN_INTEGER:    "INTEGER",
+	TOKEN_FLOAT:      "FLOAT",
+	TOKEN_STRING:     "STRING",
+	TOKEN_CREATE:     "CREATE",
+	TOKEN_DATABASE:   "DATABASE",
+	TOKEN_USE:        "USE",
+	TOKEN_DROP:       "DROP",
+	TOKEN_IF:         "IF",
+	TOKEN_EXISTS:     "EXISTS",
+	TOKEN_TABLE:      "TABLE",
+	TOKEN_ALTER:      "ALTER",
+	TOKEN_ADD:        "ADD",
+	TOKEN_COLUMN:     "COLUMN",
+	TOKEN_MODIFY:     "MODIFY",
+	TOKEN_RENAME:     "RENAME",
+	TOKEN_TO:         "TO",
+	TOKEN_SELECT:     "SELECT",
+	TOKEN_DISTINCT:   "DISTINCT",
+	TOKEN_ALL:        "ALL",
+	TOKEN_FROM:       "FROM",
+	TOKEN_WHERE:      "WHERE",
+	TOKEN_AS:         "AS",
+	TOKEN_INSERT:     "INSERT",
+	TOKEN_INTO:       "INTO",
+	TOKEN_VALUES:     "VALUES",
+	TOKEN_UPDATE:     "UPDATE",
+	TOKEN_SET:        "SET",
+	TOKEN_DELETE:     "DELETE",
+	TOKEN_JOIN:       "JOIN",
+	TOKEN_INNER:      "INNER",
+	TOKEN_LEFT:       "LEFT",
+	TOKEN_RIGHT:      "RIGHT",
+	TOKEN_FULL:       "FULL",
+	TOKEN_OUTER:      "OUTER",
+	TOKEN_CROSS:      "CROSS",
+	TOKEN_ON:         "ON",
+	TOKEN_GROUP:      "GROUP",
+	TOKEN_BY:         "BY",
+	TOKEN_HAVING:     "HAVING",
+	TOKEN_ORDER:      "ORDER",
+	TOKEN_ASC:        "ASC",
+	TOKEN_DESC:       "DESC",
+	TOKEN_LIMIT:      "LIMIT",
+	TOKEN_OFFSET:     "OFFSET",
+	TOKEN_PRIMARY:    "PRIMARY",
+	TOKEN_KEY:        "KEY",
+	TOKEN_NOT:        "NOT",
+	TOKEN_NULL:       "NULL",
+	TOKEN_DEFAULT:    "DEFAULT",
+	TOKEN_UNIQUE:     "UNIQUE",
+	TOKEN_REFERENCES: "REFERENCES",
+	TOKEN_AND:        "AND",
+	TOKEN_OR:         "OR",
+	TOKEN_TRUE:       "TRUE",
+	TOKEN_FALSE:      "FALSE",
+	TOKEN_LIKE:       "LIKE",
+	TOKEN_IS:         "IS",
+	TOKEN_IN:         "IN",
+	TOKEN_BETWEEN:    "BETWEEN",
+	TOKEN_INT:        "INT",
+	TOKEN_BIGINT:     "BIGINT",
+	TOKEN_VARCHAR:    "VARCHAR",
+	TOKEN_BOOLEAN:    "BOOLEAN",
+	TOKEN_TEXT:       "TEXT",
+	TOKEN_TIMESTAMP:  "TIMESTAMP",
+	TOKEN_EQ:         "=",
+	TOKEN_NEQ:        "!=",
+	TOKEN_LT:         "<",
+	TOKEN_GT:         ">",
+	TOKEN_LTE:        "<=",
+	TOKEN_GTE:        ">=",
+	TOKEN_PLUS:       "+",
+	TOKEN_MINUS:      "-",
+	TOKEN_STAR:       "*",
+	TOKEN_SLASH:      "/",
+	TOKEN_PERCENT:    "%",
+	TOKEN_LPAREN:     "(",
+	TOKEN_RPAREN:     ")",
+	TOKEN_COMMA:      ",",
+	TOKEN_DOT:        ".",
+	TOKEN_SEMICOLON:  ";",
+}
+
+// String returns the tokenNames label for t, or "TokenType(N)" when t has no entry.
+func (t TokenType) String() string {
+	if s, ok := tokenNames[t]; ok {
+		return s
+	}
+	return fmt.Sprintf("TokenType(%d)", int(t))
+}
+
+// Token is a single lexical unit produced by the Lexer.
+type Token struct {
+	Type    TokenType // what kind of token this is
+	Literal string    // source text; string tokens are unquoted with escapes decoded, keywords keep their casing
+	Line    int       // 1-based line number of the token's first character
+	Col     int       // 1-based column number of the token's first character
+}
+
+// String renders t as Token{TYPE "literal" line:col}; TYPE is left-justified to 12 columns.
+func (t Token) String() string {
+	return fmt.Sprintf("Token{%-12s %q %d:%d}", t.Type, t.Literal, t.Line, t.Col)
+}
diff --git a/internal/sql/lexer/tokens_test.go b/internal/sql/lexer/tokens_test.go
new file mode 100644
index 0000000..cadb4fe
--- /dev/null
+++ b/internal/sql/lexer/tokens_test.go
@@ -0,0 +1,126 @@
+package lexer
+
+import (
+	"fmt"
+	"testing"
+)
+
+// ---------- TokenType.String() -----------------------------------------------
+
+func TestTokenType_String_KnownTypes(t *testing.T) {
+	// Every entry in the tokenNames map should be returned by String().
+	for tt, name := range tokenNames {
+		t.Run(name, func(t *testing.T) {
+			got := tt.String()
+			if got != name {
+				t.Errorf("TokenType(%d).String() = %q, want %q", int(tt), got, name)
+			}
+		})
+	}
+}
+
+func TestTokenType_String_UnknownType(t *testing.T) {
+	unknown := TokenType(9999)
+	got := unknown.String()
+	want := fmt.Sprintf("TokenType(%d)", 9999)
+	if got != want {
+		t.Errorf("got %q, want %q", got, want)
+	}
+}
+
+func TestTokenType_String_AllTokenTypesHaveNames(t *testing.T) {
+	// Token types are contiguous iota values from TOKEN_EOF (0) through
+	// TOKEN_SEMICOLON, so walking that range visits every declared type,
+	// including any added later, without a hand-maintained list to go stale.
+	for tt := TOKEN_EOF; tt <= TOKEN_SEMICOLON; tt++ {
+		if _, ok := tokenNames[tt]; !ok {
+			t.Errorf("TokenType %d has no human-readable name in tokenNames", int(tt))
+		}
+	}
+}
+
+// ---------- Token.String() ---------------------------------------------------
+
+func TestToken_String(t *testing.T) {
+	// Token.String pads the type name to 12 columns (%-12s), so the expected
+	// strings carry that padding between the type and the quoted literal.
+	tests := []struct {
+		token Token
+		want  string
+	}{
+		{
+			Token{Type: TOKEN_SELECT, Literal: "SELECT", Line: 1, Col: 1},
+			`Token{SELECT       "SELECT" 1:1}`,
+		},
+		{
+			Token{Type: TOKEN_INTEGER, Literal: "42", Line: 3, Col: 15},
+			`Token{INTEGER      "42" 3:15}`,
+		},
+		{
+			Token{Type: TOKEN_STRING, Literal: "hello", Line: 1, Col: 10},
+			`Token{STRING       "hello" 1:10}`,
+		},
+		{
+			Token{Type: TOKEN_EOF, Literal: "", Line: 5, Col: 1},
+			`Token{EOF          "" 5:1}`,
+		},
+		{
+			Token{Type: TOKEN_ILLEGAL, Literal: "@", Line: 1, Col: 1},
+			`Token{ILLEGAL      "@" 1:1}`,
+		},
+	}
+	for _, tc := range tests {
+		t.Run(tc.want, func(t *testing.T) {
+			got := tc.token.String()
+			if got != tc.want {
+				t.Errorf("got %q\nwant %q", got, tc.want)
+			}
+		})
+	}
+}
+
+// ---------- lookupIdent (keywords.go) ----------------------------------------
+
+func TestLookupIdent_ReturnsKeywordType(t *testing.T) {
+	for word, expected := range keywords {
+		got := lookupIdent(word)
+		if got != expected {
+			t.Errorf("lookupIdent(%q) = %v, want %v", word, got, expected)
+		}
+	}
+}
+
+func TestLookupIdent_CaseInsensitive(t *testing.T) {
+	tests := []struct {
+		input string
+		want  TokenType
+	}{
+		{"select", TOKEN_SELECT},
+		{"SELECT", TOKEN_SELECT},
+		{"SeLeCt", TOKEN_SELECT},
+		{"from", TOKEN_FROM},
+		{"FROM", TOKEN_FROM},
+		{"fRoM", TOKEN_FROM},
+	}
+	for _, tc := range tests {
+		t.Run(tc.input, func(t *testing.T) {
+			got := lookupIdent(tc.input)
+			if got != tc.want {
+				t.Errorf("lookupIdent(%q) = %v, want %v", tc.input, got, tc.want)
+			}
+		})
+	}
+}
+
+func TestLookupIdent_ReturnsIdentForNonKeywords(t *testing.T) {
+	nonKeywords := []string{
+		"foo", "bar", "my_table", "userId", "x", "_private",
+		"selection", "fromage", "orderly", "deleteme",
+	}
+	for _, word := range nonKeywords {
+		got := lookupIdent(word)
+		if got != TOKEN_IDENT {
+			t.Errorf("lookupIdent(%q) = %v, want TOKEN_IDENT", word, got)
+		}
+	}
+}