diff --git a/.gitignore b/.gitignore index 316b294..110e28a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ # Build outputs erdn +erdn-mcp *.wasm # Website build artifacts (generated during CI) diff --git a/.mcp.json b/.mcp.json new file mode 100644 index 0000000..d648bc5 --- /dev/null +++ b/.mcp.json @@ -0,0 +1,9 @@ +{ + "mcpServers": { + "erdn-lang": { + "type": "stdio", + "command": "go", + "args": ["run", "github.com/headercat/erdn-lang/cmd/erdn-mcp@latest"] + } + } +} diff --git a/README.md b/README.md index e3dc0f5..3865bda 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ Maintaining ERD diagrams is tedious: graphical tools produce binary files that a - [Installation](#installation) - [Usage](#usage) + - [MCP server](#mcp-server) - [Syntax](#syntax) - [Comments](#comments) - [Tables](#tables) @@ -34,7 +35,11 @@ You need [Go 1.21](https://go.dev/dl/) or later. ### Install via `go install` ```sh +# CLI tool go install github.com/headercat/erdn-lang/cmd/erdn@latest + +# MCP server +go install github.com/headercat/erdn-lang/cmd/erdn-mcp@latest ``` ### Build from source @@ -99,6 +104,57 @@ erdn sql schema.erdn --dbms postgresql erdn sql schema.erdn --dbms mssql --out migrations/001_init.sql ``` +### MCP server + +**erdn-lang** ships a local [Model Context Protocol](https://modelcontextprotocol.io/) server so AI assistants and MCP-compatible editors can convert SQL schemas to ERDN and generate diagrams. + +#### Install + +```sh +go install github.com/headercat/erdn-lang/cmd/erdn-mcp@latest +``` +```json +{ + "mcpServers": { + "erdn-lang": { + "type": "stdio", + "command": "go", + "args": ["run", "github.com/headercat/erdn-lang/cmd/erdn-mcp@latest"] + } + } +} +``` + +Copy this block into your MCP client's configuration file, or use the `.mcp.json` at the root of this repository if your client supports auto-discovery. + +If you have already run `go install` above, you can replace the `"command"/"args"` with the installed binary: + +```json +{ + "mcpServers": { + "erdn-lang": { + "type": "stdio", + "command": "erdn-mcp" + } + } +} +``` + +#### Running manually + +If you do have a local clone you can also run the server directly: + +```sh +go run ./cmd/erdn-mcp +``` + +#### Tools + +The server exposes two tools: + +- `convert_sql_to_erdn` — converts SQL `CREATE TABLE` / `FOREIGN KEY` schema text into ERDN source. +- `generate_svg_from_erdn` — validates ERDN and returns generated SVG diagram text. + The generated SQL includes: - `CREATE TABLE` statements with DBMS-appropriate types and constraints. diff --git a/cmd/erdn-mcp/main.go b/cmd/erdn-mcp/main.go new file mode 100644 index 0000000..753fb28 --- /dev/null +++ b/cmd/erdn-mcp/main.go @@ -0,0 +1,270 @@ +package main + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "io" + "os" + "strconv" + "strings" + + "github.com/headercat/erdn-lang/internal/parser" + "github.com/headercat/erdn-lang/internal/render" + "github.com/headercat/erdn-lang/internal/semantic" + "github.com/headercat/erdn-lang/internal/sqlimport" +) + +type rpcRequest struct { + JSONRPC string `json:"jsonrpc"` + ID json.RawMessage `json:"id,omitempty"` + Method string `json:"method"` + Params json.RawMessage `json:"params,omitempty"` +} + +type rpcResponse struct { + JSONRPC string `json:"jsonrpc"` + ID interface{} `json:"id,omitempty"` + Result interface{} `json:"result,omitempty"` + Error *rpcError `json:"error,omitempty"` +} + +type rpcError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +func main() { + br := bufio.NewReader(os.Stdin) + bw := bufio.NewWriter(os.Stdout) + for { + msg, err := readMessage(br) + if err != nil { + if err == io.EOF { + return + } + _ = writeJSON(bw, rpcResponse{ + JSONRPC: "2.0", + Error: &rpcError{Code: -32700, Message: "parse error"}, + }) + continue + } + + var req rpcRequest + if err := json.Unmarshal(msg, &req); err != nil { + _ = writeJSON(bw, rpcResponse{ + JSONRPC: "2.0", + Error: &rpcError{Code: -32600, Message: "invalid request"}, + }) + continue + } + + // Notifications do not require responses (missing or null id). + if isNotificationID(req.ID) { + continue + } + id := decodeID(req.ID) + + switch req.Method { + case "initialize": + _ = writeJSON(bw, rpcResponse{ + JSONRPC: "2.0", + ID: id, + Result: map[string]interface{}{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]interface{}{ + "tools": map[string]interface{}{}, + }, + "serverInfo": map[string]string{ + "name": "erdn-lang-mcp", + "version": "0.1.0", + }, + }, + }) + case "tools/list": + _ = writeJSON(bw, rpcResponse{ + JSONRPC: "2.0", + ID: id, + Result: map[string]interface{}{ + "tools": []map[string]interface{}{ + { + "name": "convert_sql_to_erdn", + "description": "Convert SQL CREATE TABLE/FK schema text to ERDN source code.", + "inputSchema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "sql": map[string]string{ + "type": "string", + "description": "SQL DDL text containing CREATE TABLE statements.", + }, + }, + "required": []string{"sql"}, + }, + }, + { + "name": "generate_svg_from_erdn", + "description": "Validate ERDN source and generate an SVG diagram.", + "inputSchema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "erdn": map[string]string{ + "type": "string", + "description": "ERDN schema source text.", + }, + }, + "required": []string{"erdn"}, + }, + }, + }, + }, + }) + case "tools/call": + res, rerr := handleToolCall(req.Params) + if rerr != nil { + _ = writeJSON(bw, rpcResponse{ + JSONRPC: "2.0", + ID: id, + Result: map[string]interface{}{ + "content": []map[string]string{ + {"type": "text", "text": rerr.Error()}, + }, + "isError": true, + }, + }) + continue + } + _ = writeJSON(bw, rpcResponse{ + JSONRPC: "2.0", + ID: id, + Result: res, + }) + default: + _ = writeJSON(bw, rpcResponse{ + JSONRPC: "2.0", + ID: id, + Error: &rpcError{Code: -32601, Message: "method not found"}, + }) + } + } +} + +func handleToolCall(params json.RawMessage) (map[string]interface{}, error) { + var call struct { + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments"` + } + if err := json.Unmarshal(params, &call); err != nil { + return nil, fmt.Errorf("invalid tools/call params: %w", err) + } + + switch call.Name { + case "convert_sql_to_erdn": + sql, _ := call.Arguments["sql"].(string) + if strings.TrimSpace(sql) == "" { + return nil, fmt.Errorf("sql is required") + } + prog, err := sqlimport.ParseDDL(sql) + if err != nil { + return nil, err + } + erdn := sqlimport.ToERDN(prog) + return map[string]interface{}{ + "content": []map[string]string{ + {"type": "text", "text": erdn}, + }, + }, nil + case "generate_svg_from_erdn": + src, _ := call.Arguments["erdn"].(string) + if strings.TrimSpace(src) == "" { + return nil, fmt.Errorf("erdn is required") + } + prog, err := parser.ParseString(src) + if err != nil { + return nil, err + } + if errs := semantic.Validate(prog); len(errs) > 0 { + var lines []string + for _, e := range errs { + lines = append(lines, e.Error()) + } + return nil, fmt.Errorf(strings.Join(lines, "\n")) + } + svg := render.GenerateSVG(prog) + return map[string]interface{}{ + "content": []map[string]string{ + {"type": "text", "text": svg}, + }, + }, nil + default: + return nil, fmt.Errorf("unknown tool: %s", call.Name) + } +} + +func decodeID(raw json.RawMessage) interface{} { + var v interface{} + if err := json.Unmarshal(raw, &v); err != nil { + return string(raw) + } + return v +} + +func isNotificationID(raw json.RawMessage) bool { + if len(raw) == 0 { + return true + } + var v interface{} + if err := json.Unmarshal(raw, &v); err != nil { + return false + } + return v == nil +} + +func readMessage(r *bufio.Reader) ([]byte, error) { + headers := map[string]string{} + for { + line, err := r.ReadString('\n') + if err != nil { + return nil, err + } + line = strings.TrimRight(line, "\r\n") + if line == "" { + break + } + parts := strings.SplitN(line, ":", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("invalid header line") + } + headers[strings.ToLower(strings.TrimSpace(parts[0]))] = strings.TrimSpace(parts[1]) + } + cl, ok := headers["content-length"] + if !ok { + return nil, fmt.Errorf("missing content-length header") + } + n, err := strconv.Atoi(cl) + if err != nil { + return nil, fmt.Errorf("invalid content-length %q: %w", cl, err) + } + if n < 0 { + return nil, fmt.Errorf("invalid content-length %d: must be >= 0", n) + } + buf := make([]byte, n) + if _, err := io.ReadFull(r, buf); err != nil { + return nil, err + } + return buf, nil +} + +func writeJSON(w *bufio.Writer, payload interface{}) error { + body, err := json.Marshal(payload) + if err != nil { + return err + } + var frame bytes.Buffer + fmt.Fprintf(&frame, "Content-Length: %d\r\n\r\n", len(body)) + frame.Write(body) + if _, err := w.Write(frame.Bytes()); err != nil { + return err + } + return w.Flush() +} diff --git a/internal/sqlimport/sqlimport.go b/internal/sqlimport/sqlimport.go new file mode 100644 index 0000000..1dcd4a4 --- /dev/null +++ b/internal/sqlimport/sqlimport.go @@ -0,0 +1,598 @@ +package sqlimport + +import ( + "fmt" + "strings" + "unicode" + + "github.com/headercat/erdn-lang/internal/ast" + "github.com/headercat/erdn-lang/internal/parser" +) + +const defaultColumnType = "text" + +// ParseDDL parses SQL CREATE TABLE statements and converts them into an ERDN AST. +// It supports table-level/inline primary keys and foreign keys. +func ParseDDL(sql string) (*ast.Program, error) { + clean := stripSQLComments(sql) + + var tables []*ast.Table + var links []*ast.Link + seenLinks := map[string]bool{} + + rest := clean + for { + idx := indexCreateTable(rest) + if idx < 0 { + break + } + rest = rest[idx+len("create table"):] + + tableName, afterName, err := parseIdentifier(rest) + if err != nil { + return nil, fmt.Errorf("parse table name: %w", err) + } + rest = afterName + tableName = sanitizeIdentifier(baseIdentifier(tableName)) + + open := strings.Index(rest, "(") + if open < 0 { + return nil, fmt.Errorf("missing '(' after CREATE TABLE %s", tableName) + } + rest = rest[open:] + body, afterBody, err := extractParenthesized(rest) + if err != nil { + return nil, fmt.Errorf("parse table body for %s: %w", tableName, err) + } + rest = afterBody + + tbl, tblLinks, err := parseTableBody(tableName, body) + if err != nil { + return nil, err + } + tables = append(tables, tbl) + for _, l := range tblLinks { + key := fmt.Sprintf("%s.%s>%s.%s", l.FromTable, l.FromColumn, l.ToTable, l.ToColumn) + if seenLinks[key] { + continue + } + seenLinks[key] = true + links = append(links, l) + } + } + + return &ast.Program{Tables: tables, Links: links}, nil +} + +// ToERDN formats an ERDN AST as textual erdn-lang source. +func ToERDN(prog *ast.Program) string { + var b strings.Builder + for i, t := range prog.Tables { + if i > 0 { + b.WriteString("\n\n") + } + fmt.Fprintf(&b, "table %s (\n", t.Name) + for _, c := range t.Columns { + fmt.Fprintf(&b, " %s %s", c.Name, parser.FormatType(c)) + if mods := formatModifiers(c.Modifiers); mods != "" { + b.WriteByte(' ') + b.WriteString(mods) + } + b.WriteByte('\n') + } + b.WriteString(")") + } + + if len(prog.Links) > 0 { + b.WriteString("\n\n") + for i, l := range prog.Links { + if i > 0 { + b.WriteByte('\n') + } + fmt.Fprintf(&b, "link %s %s.%s to %s %s.%s", + cardToText(l.FromCardinality), l.FromTable, l.FromColumn, + cardToText(l.ToCardinality), l.ToTable, l.ToColumn) + } + } + + b.WriteByte('\n') + return b.String() +} + +func parseTableBody(tableName, body string) (*ast.Table, []*ast.Link, error) { + parts := splitTopLevelComma(body) + tbl := &ast.Table{Name: tableName} + + var links []*ast.Link + pkCols := map[string]bool{} + + for _, raw := range parts { + part := strings.TrimSpace(raw) + if part == "" { + continue + } + lower := strings.ToLower(part) + + if strings.HasPrefix(lower, "primary key") { + for _, c := range parseColumnList(part) { + pkCols[sanitizeIdentifier(baseIdentifier(c))] = true + } + continue + } + if strings.HasPrefix(lower, "constraint") && strings.Contains(lower, " primary key") { + for _, c := range parseColumnList(part) { + pkCols[sanitizeIdentifier(baseIdentifier(c))] = true + } + continue + } + + if strings.HasPrefix(lower, "foreign key") { + localCols := parseColumnList(part) + refTable, refCols := parseFKReference(part) + refTable = sanitizeIdentifier(baseIdentifier(refTable)) + n := len(localCols) + if len(refCols) < n { + n = len(refCols) + } + for i := 0; i < n; i++ { + links = append(links, &ast.Link{ + FromTable: sanitizeIdentifier(baseIdentifier(refTable)), + FromColumn: sanitizeIdentifier(baseIdentifier(refCols[i])), + ToTable: tableName, + ToColumn: sanitizeIdentifier(baseIdentifier(localCols[i])), + FromCardinality: ast.CardOne, + ToCardinality: ast.CardMany, + }) + } + continue + } + if strings.HasPrefix(lower, "constraint") && strings.Contains(lower, " foreign key") { + localCols := parseColumnList(part) + refTable, refCols := parseFKReference(part) + refTable = sanitizeIdentifier(baseIdentifier(refTable)) + n := len(localCols) + if len(refCols) < n { + n = len(refCols) + } + for i := 0; i < n; i++ { + links = append(links, &ast.Link{ + FromTable: sanitizeIdentifier(baseIdentifier(refTable)), + FromColumn: sanitizeIdentifier(baseIdentifier(refCols[i])), + ToTable: tableName, + ToColumn: sanitizeIdentifier(baseIdentifier(localCols[i])), + FromCardinality: ast.CardOne, + ToCardinality: ast.CardMany, + }) + } + continue + } + + col, inlineRef := parseColumnDef(part) + if col == nil { + continue + } + tbl.Columns = append(tbl.Columns, col) + if inlineRef != nil { + links = append(links, &ast.Link{ + FromTable: inlineRef.table, + FromColumn: inlineRef.column, + ToTable: tableName, + ToColumn: col.Name, + FromCardinality: ast.CardOne, + ToCardinality: ast.CardMany, + }) + } + } + + for _, c := range tbl.Columns { + if pkCols[c.Name] { + c.Modifiers = append(c.Modifiers, ast.Modifier{Kind: ast.ModPrimaryKey}) + } + } + + return tbl, links, nil +} + +type ref struct { + table string + column string +} + +func parseColumnDef(part string) (*ast.Column, *ref) { + name, rest, err := parseIdentifier(part) + if err != nil { + return nil, nil + } + name = sanitizeIdentifier(baseIdentifier(name)) + rest = strings.TrimSpace(rest) + if rest == "" { + return nil, nil + } + + typeExpr, modifiersExpr := splitTypeAndModifiers(rest) + col := &ast.Column{Name: name} + populateType(col, typeExpr) + + lower := strings.ToLower(modifiersExpr) + if strings.Contains(lower, "primary key") { + col.Modifiers = append(col.Modifiers, ast.Modifier{Kind: ast.ModPrimaryKey}) + } + if strings.Contains(lower, "not null") { + col.Modifiers = append(col.Modifiers, ast.Modifier{Kind: ast.ModNotNull}) + } else if strings.Contains(lower, " null") { + col.Modifiers = append(col.Modifiers, ast.Modifier{Kind: ast.ModNullable}) + } + if strings.Contains(lower, "auto_increment") || + strings.Contains(lower, " identity") || + strings.Contains(lower, "generated always as identity") { + col.Modifiers = append(col.Modifiers, ast.Modifier{Kind: ast.ModAutoIncrement}) + } + if containsWord(lower, "index") { + col.Modifiers = append(col.Modifiers, ast.Modifier{Kind: ast.ModIndexed}) + } + + if def := parseDefault(modifiersExpr); def != "" { + col.Modifiers = append(col.Modifiers, ast.Modifier{Kind: ast.ModDefault, Value: def}) + } + + if strings.Contains(lower, "references ") { + tbl, cols := parseFKReference(part) + if tbl != "" && len(cols) > 0 { + return col, &ref{ + table: sanitizeIdentifier(baseIdentifier(tbl)), + column: sanitizeIdentifier(baseIdentifier(cols[0])), + } + } + } + + return col, nil +} + +func parseDefault(s string) string { + lower := strings.ToLower(s) + i := strings.Index(lower, " default ") + if i < 0 { + return "" + } + after := strings.TrimSpace(s[i+len(" default "):]) + if after == "" { + return "" + } + + if after[0] == '\'' { + end := strings.Index(after[1:], "'") + if end >= 0 { + return `"` + after[1:1+end] + `"` + } + } + if after[0] == '"' { + end := strings.Index(after[1:], `"`) + if end >= 0 { + return `"` + after[1:1+end] + `"` + } + } + upper := strings.ToUpper(after) + if strings.HasPrefix(upper, "CURRENT_TIMESTAMP") { + return "NOW()" + } + if strings.HasPrefix(upper, "NOW()") { + return "NOW()" + } + + for i := 0; i < len(after); i++ { + if unicode.IsSpace(rune(after[i])) || after[i] == ',' { + return after[:i] + } + } + return after +} + +func populateType(col *ast.Column, typeExpr string) { + typeExpr = strings.TrimSpace(typeExpr) + if typeExpr == "" { + col.Type = defaultColumnType + return + } + open := strings.Index(typeExpr, "(") + if open < 0 || !strings.HasSuffix(typeExpr, ")") { + col.Type = strings.ToLower(strings.TrimSpace(typeExpr)) + return + } + col.Type = strings.ToLower(strings.TrimSpace(typeExpr[:open])) + params := strings.TrimSpace(typeExpr[open+1 : len(typeExpr)-1]) + if params == "" { + return + } + for _, p := range strings.Split(params, ",") { + col.TypeParams = append(col.TypeParams, strings.TrimSpace(p)) + } +} + +func splitTypeAndModifiers(s string) (string, string) { + var b strings.Builder + depth := 0 + r := []rune(s) + for i := 0; i < len(r); i++ { + ch := r[i] + if ch == '(' { + depth++ + } else if ch == ')' && depth > 0 { + depth-- + } + if depth == 0 && unicode.IsSpace(ch) { + remaining := strings.TrimSpace(string(r[i:])) + lower := strings.ToLower(remaining) + if startsConstraint(lower) { + return strings.TrimSpace(b.String()), remaining + } + } + b.WriteRune(ch) + } + return strings.TrimSpace(b.String()), "" +} + +func startsConstraint(s string) bool { + for _, k := range []string{ + "not null", "null", "default", "primary key", "unique", "references", + "constraint", "auto_increment", "identity", "generated", + } { + if strings.HasPrefix(s, k) { + return true + } + } + return false +} + +func parseFKReference(s string) (string, []string) { + lower := strings.ToLower(s) + i := strings.Index(lower, "references") + if i < 0 { + return "", nil + } + rest := strings.TrimSpace(s[i+len("references"):]) + table, rest, err := parseIdentifier(rest) + if err != nil { + return "", nil + } + cols := parseColumnList(rest) + return table, cols +} + +func parseColumnList(s string) []string { + open := strings.Index(s, "(") + if open < 0 { + return nil + } + body, _, err := extractParenthesized(s[open:]) + if err != nil { + return nil + } + parts := splitTopLevelComma(body) + out := make([]string, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p == "" { + continue + } + name, _, err := parseIdentifier(p) + if err != nil { + continue + } + out = append(out, name) + } + return out +} + +func stripSQLComments(s string) string { + var out []string + for _, line := range strings.Split(s, "\n") { + if idx := strings.Index(line, "--"); idx >= 0 { + line = line[:idx] + } + out = append(out, line) + } + return strings.Join(out, "\n") +} + +func indexCreateTable(s string) int { + return strings.Index(strings.ToLower(s), "create table") +} + +func extractParenthesized(s string) (inside, rest string, err error) { + if len(s) == 0 || s[0] != '(' { + return "", "", fmt.Errorf("expected '('") + } + depth := 0 + inSingle := false + inDouble := false + r := []rune(s) + for i, ch := range r { + switch ch { + case '\'': + if !inDouble { + inSingle = !inSingle + } + case '"': + if !inSingle { + inDouble = !inDouble + } + } + if inSingle || inDouble { + continue + } + if ch == '(' { + depth++ + } + if ch == ')' { + depth-- + if depth == 0 { + return string(r[1:i]), string(r[i+1:]), nil + } + } + } + return "", "", fmt.Errorf("unterminated parenthesized block") +} + +func splitTopLevelComma(s string) []string { + var parts []string + depth := 0 + inSingle := false + inDouble := false + start := 0 + r := []rune(s) + for i, ch := range r { + switch ch { + case '\'': + if !inDouble { + inSingle = !inSingle + } + case '"': + if !inSingle { + inDouble = !inDouble + } + } + if inSingle || inDouble { + continue + } + if ch == '(' { + depth++ + continue + } + if ch == ')' { + if depth > 0 { + depth-- + } + continue + } + if ch == ',' && depth == 0 { + parts = append(parts, string(r[start:i])) + start = i + 1 + } + } + parts = append(parts, string(r[start:])) + return parts +} + +func parseIdentifier(s string) (ident, rest string, err error) { + s = strings.TrimLeftFunc(s, unicode.IsSpace) + if s == "" { + return "", "", fmt.Errorf("empty identifier") + } + switch s[0] { + case '`': + end := strings.Index(s[1:], "`") + if end < 0 { + return "", "", fmt.Errorf("unterminated backtick identifier") + } + return s[1 : 1+end], s[2+end:], nil + case '"': + end := strings.Index(s[1:], `"`) + if end < 0 { + return "", "", fmt.Errorf("unterminated quoted identifier") + } + return s[1 : 1+end], s[2+end:], nil + case '[': + end := strings.Index(s[1:], "]") + if end < 0 { + return "", "", fmt.Errorf("unterminated bracket identifier") + } + return s[1 : 1+end], s[2+end:], nil + default: + i := 0 + for i < len(s) { + ch := s[i] + if unicode.IsSpace(rune(ch)) || ch == '(' || ch == ')' || ch == ',' { + break + } + i++ + } + if i == 0 { + return "", "", fmt.Errorf("invalid identifier") + } + return s[:i], s[i:], nil + } +} + +func baseIdentifier(s string) string { + s = strings.TrimSpace(s) + if idx := strings.LastIndex(s, "."); idx >= 0 { + return strings.TrimSpace(s[idx+1:]) + } + return s +} + +func sanitizeIdentifier(s string) string { + s = strings.TrimSpace(s) + if s == "" { + return "_" + } + var out []rune + for i, r := range []rune(s) { + ok := isWordChar(r) + if i == 0 && !(unicode.IsLetter(r) || r == '_') { + ok = false + } + if ok { + out = append(out, r) + } else { + if len(out) == 0 || out[len(out)-1] != '_' { + out = append(out, '_') + } + } + } + result := strings.Trim(string(out), "_") + if result == "" { + return "_" + } + return result +} + +func containsWord(s, word string) bool { + start := 0 + for { + i := strings.Index(s[start:], word) + if i < 0 { + return false + } + i += start + beforeOK := i == 0 || !isWordChar(rune(s[i-1])) + after := i + len(word) + afterOK := after >= len(s) || !isWordChar(rune(s[after])) + if beforeOK && afterOK { + return true + } + start = i + len(word) + } +} + +func isWordChar(r rune) bool { + return unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' +} + +func cardToText(c ast.Cardinality) string { + if c == ast.CardMany { + return "many" + } + return "one" +} + +func formatModifiers(mods []ast.Modifier) string { + out := make([]string, 0, len(mods)) + for _, m := range mods { + switch m.Kind { + case ast.ModPrimaryKey: + out = append(out, "primary-key") + case ast.ModNullable: + out = append(out, "nullable") + case ast.ModNotNull: + out = append(out, "not-null") + case ast.ModAutoIncrement: + out = append(out, "auto-increment") + case ast.ModIndexed: + out = append(out, "indexed") + case ast.ModDefault: + out = append(out, "default("+m.Value+")") + } + } + return strings.Join(out, " ") +} diff --git a/internal/sqlimport/sqlimport_test.go b/internal/sqlimport/sqlimport_test.go new file mode 100644 index 0000000..fa7034e --- /dev/null +++ b/internal/sqlimport/sqlimport_test.go @@ -0,0 +1,106 @@ +package sqlimport + +import ( + "strings" + "testing" + + "github.com/headercat/erdn-lang/internal/ast" + "github.com/headercat/erdn-lang/internal/semantic" +) + +func TestParseDDL_BasicTableAndForeignKey(t *testing.T) { + sql := ` +CREATE TABLE users ( + id BIGINT PRIMARY KEY AUTO_INCREMENT, + username VARCHAR(64) NOT NULL UNIQUE +); + +CREATE TABLE posts ( + id BIGINT PRIMARY KEY, + author_id BIGINT NOT NULL, + title VARCHAR(255) NOT NULL, + CONSTRAINT fk_posts_author FOREIGN KEY (author_id) REFERENCES users(id) +);` + + prog, err := ParseDDL(sql) + if err != nil { + t.Fatalf("ParseDDL error: %v", err) + } + if len(prog.Tables) != 2 { + t.Fatalf("expected 2 tables, got %d", len(prog.Tables)) + } + if len(prog.Links) != 1 { + t.Fatalf("expected 1 link, got %d", len(prog.Links)) + } + link := prog.Links[0] + if link.FromTable != "users" { + t.Fatalf("unexpected from table: %q", link.FromTable) + } + if link.FromColumn != "id" { + t.Fatalf("unexpected from column: %q", link.FromColumn) + } + if link.ToTable != "posts" { + t.Fatalf("unexpected to table: %q", link.ToTable) + } + if link.ToColumn != "author_id" { + t.Fatalf("unexpected to column: %q", link.ToColumn) + } + if link.FromCardinality != ast.CardOne { + t.Fatalf("unexpected from cardinality: %v", link.FromCardinality) + } + if link.ToCardinality != ast.CardMany { + t.Fatalf("unexpected to cardinality: %v", link.ToCardinality) + } + if errs := semantic.Validate(prog); len(errs) > 0 { + t.Fatalf("semantic errors: %v", errs) + } + + erdn := ToERDN(prog) + if !strings.Contains(erdn, "table users (") { + t.Fatalf("missing users table in ERDN:\n%s", erdn) + } + if !strings.Contains(erdn, "link one users.id to many posts.author_id") { + t.Fatalf("missing expected link in ERDN:\n%s", erdn) + } +} + +func TestParseDDL_QuotedAndNonLatinIdentifiers(t *testing.T) { + sql := ` +CREATE TABLE "用户" ( + "编号" BIGINT PRIMARY KEY, + "名称" VARCHAR(128) NOT NULL +);` + + prog, err := ParseDDL(sql) + if err != nil { + t.Fatalf("ParseDDL error: %v", err) + } + if len(prog.Tables) != 1 { + t.Fatalf("expected 1 table, got %d", len(prog.Tables)) + } + if prog.Tables[0].Name != "用户" { + t.Fatalf("expected table name 用户, got %q", prog.Tables[0].Name) + } + if len(prog.Tables[0].Columns) != 2 { + t.Fatalf("expected 2 columns, got %d", len(prog.Tables[0].Columns)) + } +} + +func TestParseDDL_InlineReference(t *testing.T) { + sql := ` +CREATE TABLE customers ( + id BIGINT PRIMARY KEY +); +CREATE TABLE orders ( + id BIGINT PRIMARY KEY, + customer_id BIGINT REFERENCES customers(id) +);` + + prog, err := ParseDDL(sql) + if err != nil { + t.Fatalf("ParseDDL error: %v", err) + } + if len(prog.Links) != 1 { + t.Fatalf("expected 1 link, got %d", len(prog.Links)) + } +} diff --git a/website/.vitepress/config.mts b/website/.vitepress/config.mts index 98e0ffb..ddeae68 100644 --- a/website/.vitepress/config.mts +++ b/website/.vitepress/config.mts @@ -10,6 +10,7 @@ export default defineConfig({ { text: "Home", link: "/" }, { text: "Guide", link: "/guide" }, { text: "Syntax Specification", link: "/syntax" }, + { text: "MCP Server", link: "/mcp" }, { text: "Playground", link: "/playground" }, ], sidebar: [ @@ -20,6 +21,12 @@ export default defineConfig({ { text: "Syntax Specification", link: "/syntax" }, ], }, + { + text: "Integrations", + items: [ + { text: "MCP Server", link: "/mcp" }, + ], + }, ], socialLinks: [ { icon: "github", link: "https://github.com/headercat/erdn-lang" }, diff --git a/website/index.md b/website/index.md index 87fceec..ffa4836 100644 --- a/website/index.md +++ b/website/index.md @@ -26,4 +26,7 @@ features: - icon: 💬 title: Readable details: "# comments on tables, columns, and links are rendered as subtitle rows directly in the diagram." + - icon: 🤖 + title: MCP Server + details: AI-ready — connect Claude, Cursor, or any MCP-compatible assistant to convert SQL and generate diagrams directly from your editor. --- diff --git a/website/mcp.md b/website/mcp.md new file mode 100644 index 0000000..462f868 --- /dev/null +++ b/website/mcp.md @@ -0,0 +1,108 @@ +# MCP Server + +**erdn-lang** ships a built-in [Model Context Protocol](https://modelcontextprotocol.io/) (MCP) server so AI assistants and MCP-compatible editors can convert SQL schemas to ERDN and generate diagrams — no GUI, no clipboard, no copy-paste. + +## Installation + +Install the MCP server binary with `go install`: + +```sh +go install github.com/headercat/erdn-lang/cmd/erdn-mcp@latest +``` + +You need [Go 1.21](https://go.dev/dl/) or later. The binary is placed in `$GOPATH/bin` (usually `$HOME/go/bin`). Make sure that directory is on your `PATH`. + +## Client Configuration + +### Using the installed binary + +Add the following to your MCP client's configuration file (e.g. `claude_desktop_config.json`, `.cursor/mcp.json`, or VS Code's `settings.json`): + +```json +{ + "mcpServers": { + "erdn-lang": { + "type": "stdio", + "command": "erdn-mcp" + } + } +} +``` + +### Without installing (run directly from the module proxy) + +If you prefer not to install a binary, you can run the server on demand via `go run`. No local clone is needed — Go fetches the package automatically: + +```json +{ + "mcpServers": { + "erdn-lang": { + "type": "stdio", + "command": "go", + "args": ["run", "github.com/headercat/erdn-lang/cmd/erdn-mcp@latest"] + } + } +} +``` + +### Auto-discovery + +The repository root contains a ready-to-use `.mcp.json` file. MCP clients that support auto-discovery (such as recent versions of Claude Desktop and Cursor) will pick it up automatically when you open the repository folder. + +## Available Tools + +### `convert_sql_to_erdn` + +Converts SQL `CREATE TABLE` and `FOREIGN KEY` statements into ERDN source text. + +**Input** + +| Parameter | Type | Description | +|-----------|--------|-------------------------------------| +| `sql` | string | One or more SQL DDL statements | + +**Output** — The equivalent `.erdn` schema as a string. + +**Example prompt** + +> Convert this SQL schema to ERDN: +> +> ```sql +> CREATE TABLE users (id BIGINT PRIMARY KEY, username VARCHAR(64) NOT NULL); +> CREATE TABLE posts (id BIGINT PRIMARY KEY, author_id BIGINT, FOREIGN KEY (author_id) REFERENCES users(id)); +> ``` + +--- + +### `generate_svg_from_erdn` + +Validates ERDN source and returns the rendered SVG diagram as a string. + +**Input** + +| Parameter | Type | Description | +|-----------|--------|----------------------| +| `erdn` | string | ERDN schema source | + +**Output** — A self-contained SVG string that can be saved to a file or embedded in HTML. + +**Example prompt** + +> Generate an SVG diagram for this ERDN schema: +> +> ```erdn +> table users ( +> id bigint primary-key auto-increment +> username varchar(64) not-null indexed +> ) +> ``` + +## Running the Server Manually + +If you have a local clone of the repository you can also launch the server directly: + +```sh +go run ./cmd/erdn-mcp +``` + +The server communicates over **stdio** using JSON-RPC 2.0, which is the standard MCP transport.