diff --git a/v2/dbutils/ops/bobops/bobops.go b/v2/dbutils/ops/bobops/bobops.go index 4880831..df88ffc 100644 --- a/v2/dbutils/ops/bobops/bobops.go +++ b/v2/dbutils/ops/bobops/bobops.go @@ -19,6 +19,14 @@ func NewBobFilterMap(fields map[string]string) ops.FilterMap[bob.Mod[*dialect.Se type BobFilterer struct{} func (b *BobFilterer) ParseFilter(filter, alias string, op string, rawValue string, having bool) (bob.Mod[*dialect.SelectQuery], string, interface{}, error) { + return parseFilter(filter, alias, op, rawValue, having) +} + +func (b *BobFilterer) ParseSorting(sortList []string) (bob.Mod[*dialect.SelectQuery], error) { + return sm.OrderBy(strings.Join(sortList, ", ")), nil +} + +func parseFilter(filter, alias string, op string, rawValue string, having bool) (bob.Mod[*dialect.SelectQuery], string, interface{}, error) { if having { if ops.IsUnaryOp(op) { q := strings.ReplaceAll(filter, "{}", alias) @@ -56,7 +64,3 @@ func (b *BobFilterer) ParseFilter(filter, alias string, op string, rawValue stri q := strings.ReplaceAll(filter, "{}", alias) return sm.Where(psql.Raw(q, rawValue)), q, rawValue, nil } - -func (b *BobFilterer) ParseSorting(sortList []string) (bob.Mod[*dialect.SelectQuery], error) { - return sm.OrderBy(strings.Join(sortList, ", ")), nil -} diff --git a/v2/dbutils/ops/gen/README.md b/v2/dbutils/ops/gen/README.md new file mode 100644 index 0000000..477929c --- /dev/null +++ b/v2/dbutils/ops/gen/README.md @@ -0,0 +1,90 @@ +# Filter Code Generator + +This package provides automatic code generation for database filter methods based on struct field comments. + +## Overview + +Instead of manually creating `FilterMap` instances and calling `AddFilters` for each API filter, you can now: + +1. Define your endpoint payload struct with special `db:filter` comments +2. Run `go generate` to automatically create `AddFilters` methods +3. Use the generated methods in your handlers + +## Usage + +### 1. Annotate your structs + +Add `db:filter` comments to fields that should be filterable: + +```go +type ListDCRsRequest struct { + // db:filter bob_gen.ColumnNames.DCRS.Type + Type string `query:"type"` + // db:filter bob_gen.ColumnNames.DCRS.Status + Status string `query:"status"` + // db:filter bob_gen.ColumnNames.DCRS.CreatedBy + CreatedBy *string `query:"created_by"` + // db:filter bob_gen.ColumnNames.DCRS.Tags + Tags []string `query:"tags"` + + // Regular fields without filter comments are ignored + Limit int `query:"limit"` + Offset int `query:"offset"` +} +``` +Of course, this is assuming Huma. There is no support for Goa, sorry. + + +### 2. Add go generate directive + +Add this line to the top of your model files (it will also work in main.go, only a bit slower): + +```go +//go:generate go run github.com/top-solution/go-libs/v2/dbutils/ops/gen/cmd bob . +``` + +Or use the command directly: + +```bash +# Scan all folders inside ., generate bob filters +go run github.com/top-solution/go-libs/v2/dbutils/ops/gen/cmd bob . + +# Scan specific package, generate boiler filters +go run github.com/top-solution/go-libs/v2/dbutils/ops/gen/cmd boiler path/to/specific/packagh +``` + +### 3. Run go generate + +```bash +go generate ./... +``` + +Using ./.. will make sure it's going to also run //go:generate directive inside your model files. + +### 4. Use the generated methods + +The generator creates an `AddFilters` method for each annotated struct: + +```go +func (r *ListDCRsRequest) AddFilters(q *[]bob.Mod[*dialect.SelectQuery]) error +``` + +Use it in your handlers: + +```go +func ListDCRsHandler(ctx context.Context, req *ListDCRsRequest) (*ListDCRsResponse, error) { + var query []bob.Mod[*dialect.SelectQuery] + + // Automatically add filters based on request fields + if err := req.AddFilters(&query); err != nil { + return nil, err + } + + // Add other query modifications + query = append(query, sm.Limit(req.Limit)) + + // Execute query + dcrs, err := models.DCRS(query...).All(ctx, db) + // ... +} +``` diff --git a/v2/dbutils/ops/gen/cmd/main.go b/v2/dbutils/ops/gen/cmd/main.go new file mode 100644 index 0000000..57b7c8d --- /dev/null +++ b/v2/dbutils/ops/gen/cmd/main.go @@ -0,0 +1,99 @@ +package main + +import ( + "fmt" + "log" + "os" + "path/filepath" + "strings" + + "github.com/top-solution/go-libs/v2/dbutils/ops/gen" +) + +func main() { + if len(os.Args) < 3 { + log.Fatal("Usage: gen ") + } + + filterType := os.Args[1] + rootPath := os.Args[2] + + // Convert relative path to absolute for better handling + absRootPath, err := filepath.Abs(rootPath) + if err != nil { + log.Fatalf("Failed to get absolute path for %s: %v", rootPath, err) + } + + fmt.Printf("Scanning directory: %s\n", absRootPath) + + // Walk through all directories under the root path + err = filepath.Walk(absRootPath, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + // Skip if not a directory + if !info.IsDir() { + return nil + } + + // Skip hidden directories and vendor directories, but allow the root directory even if it starts with "." + if path != absRootPath && (strings.HasPrefix(info.Name(), ".") || info.Name() == "vendor") { + return filepath.SkipDir + } + + // Check if this directory contains Go files (excluding test and generated files) + hasGoFiles, err := hasRelevantGoFiles(path) + if err != nil { + return err + } + + if !hasGoFiles { + return nil + } + + // Get package name from directory name + packageName := filepath.Base(path) + + // Handle special case where the directory is "." or the root + if packageName == "." || path == absRootPath { + // Try to get package name from go.mod or use directory name + if wd, err := os.Getwd(); err == nil { + packageName = filepath.Base(wd) + } + } + + // Create generator and process the package + generator := gen.NewGenerator(packageName, path, filterType) + if err := generator.GenerateFromPackage(); err != nil { + log.Printf("Warning: Failed to generate filters for package %s: %v", path, err) + return nil // Continue processing other packages + } + + return nil + }) + + if err != nil { + log.Fatalf("Failed to walk directory tree: %v", err) + } + + fmt.Println("Filter generation completed.") +} + +// hasRelevantGoFiles checks if a directory contains Go files that are not test files or generated files +func hasRelevantGoFiles(dir string) (bool, error) { + files, err := filepath.Glob(filepath.Join(dir, "*.go")) + if err != nil { + return false, err + } + + for _, file := range files { + filename := filepath.Base(file) + // Skip test files and generated files + if !strings.HasSuffix(filename, "_test.go") && !strings.Contains(filename, "_gen.go") && !strings.Contains(filename, ".gen.go") { + return true, nil + } + } + + return false, nil +} diff --git a/v2/dbutils/ops/gen/generator.go b/v2/dbutils/ops/gen/generator.go new file mode 100644 index 0000000..af5c58a --- /dev/null +++ b/v2/dbutils/ops/gen/generator.go @@ -0,0 +1,437 @@ +package gen + +import ( + "fmt" + "go/ast" + "go/parser" + "go/token" + "os" + "path/filepath" + "regexp" + "strings" + "text/template" +) + +// FilterField represents a field that should have filter generation +type FilterField struct { + Name string // Field name in the struct + Column string // Database column name from comment + QueryParam string // Query parameter name from struct tag + Type string // Field type +} + +// StructInfo contains information about a struct that needs filter generation +type StructInfo struct { + Name string + Package string + Fields []FilterField + ReceiverName string + Imports []string // Additional imports specified in comments + SortField string // The field to sort by, if specified +} + +// Generator handles the code generation for filter methods +type Generator struct { + packageName string + packageDir string + filterType string +} + +// NewGenerator creates a new generator instance +func NewGenerator(packageName, packageDir, filterType string) *Generator { + return &Generator{ + packageName: packageName, + packageDir: packageDir, + filterType: filterType, + } +} + +// GenerateFromPackage scans all Go files in the package directory and generates filter methods +func (g *Generator) GenerateFromPackage() error { + files, err := filepath.Glob(filepath.Join(g.packageDir, "*.go")) + if err != nil { + return fmt.Errorf("failed to find Go files: %w", err) + } + + for _, file := range files { + // Skip generated files and test files + if strings.HasSuffix(file, "_test.go") || strings.Contains(file, "_gen.go") { + continue + } + + if err := g.generateFromFile(file); err != nil { + return fmt.Errorf("failed to process file %s: %w", file, err) + } + } + + return nil +} + +// GenerateFromFile parses a Go file and generates filter methods for structs with db:filter comments +func (g *Generator) GenerateFromFile(filename string) error { + return g.generateFromFile(filename) +} + +// generateFromFile is the internal implementation for processing a single file +func (g *Generator) generateFromFile(filename string) error { + structs, err := g.parseFile(filename) + if err != nil { + return fmt.Errorf("failed to parse file %s: %w", filename, err) + } + + if len(structs) == 0 { + return nil // No structs with filter comments found + } + + var structNames []string + for _, s := range structs { + structNames = append(structNames, s.Name) + } + + fmt.Printf("Processing file %s (%v)\n", filename, strings.Join(structNames, ", ")) + + // Generate output filename: file.go -> file_filters.gen.go + outputFile := g.getOutputFilename(filename) + return g.generateCode(structs, outputFile) +} + +// getOutputFilename generates the output filename based on the input filename +func (g *Generator) getOutputFilename(inputFile string) string { + dir := filepath.Dir(inputFile) + base := filepath.Base(inputFile) + ext := filepath.Ext(base) + name := base[:len(base)-len(ext)] + return filepath.Join(dir, name+"_filters.gen.go") +} + +// parseFile parses a Go file and extracts struct information +func (g *Generator) parseFile(filename string) ([]StructInfo, error) { + fset := token.NewFileSet() + node, err := parser.ParseFile(fset, filename, nil, parser.ParseComments) + if err != nil { + return nil, err + } + + var structs []StructInfo + + // Create a map of positions to comments for easier lookup + commentMap := make(map[token.Pos]*ast.CommentGroup) + for _, cg := range node.Comments { + commentMap[cg.Pos()] = cg + } + + ast.Inspect(node, func(n ast.Node) bool { + switch x := n.(type) { + case *ast.GenDecl: + if x.Tok == token.TYPE { + for _, spec := range x.Specs { + if typeSpec, ok := spec.(*ast.TypeSpec); ok { + if structType, ok := typeSpec.Type.(*ast.StructType); ok { + // Check for struct-level db:filter comment in the GenDecl doc + hasFilter, imports, sortField := g.parseFilterComments(x.Doc) + if hasFilter { + structInfo := g.parseStruct(typeSpec.Name.Name, structType) + if len(structInfo.Fields) > 0 { + structInfo.Package = node.Name.Name + structInfo.Imports = imports + structInfo.SortField = sortField + structs = append(structs, structInfo) + } + } + } + } + } + } + } + return true + }) + + return structs, nil +} + +// parseFilterComments checks if the struct has a db:filter comment and extracts imports +func (g *Generator) parseFilterComments(doc *ast.CommentGroup) (bool, []string, string) { + if doc == nil { + return false, nil, "" + } + + structFilterRegex := regexp.MustCompile(`//\s*db:filter\s*$`) + importRegex := regexp.MustCompile(`//\s*db:filter\s+import\s+(.+)`) + sortRegex := regexp.MustCompile(`//\s*db:filter\s+sortField\s+(.+)`) + + hasFilter := false + var imports []string + var sortField string + + for _, comment := range doc.List { + if structFilterRegex.MatchString(comment.Text) { + hasFilter = true + } else if matches := importRegex.FindStringSubmatch(comment.Text); len(matches) > 1 { + importSpec := strings.TrimSpace(matches[1]) + imports = append(imports, g.parseImportSpec(importSpec)) + } else if matches := sortRegex.FindStringSubmatch(comment.Text); len(matches) > 1 { + sortField = strings.TrimSpace(matches[1]) + } + + } + + return hasFilter, imports, sortField +} + +// parseImportSpec parses import specifications with optional aliases +// Supports formats like: +// - "package" +// - package +// - alias "package" +// - alias package +func (g *Generator) parseImportSpec(spec string) string { + spec = strings.TrimSpace(spec) + + // Check if it contains a space (indicating an alias) + parts := strings.Fields(spec) + + if len(parts) == 1 { + // Single part - just a package path + pkg := strings.Trim(parts[0], `"`) + return `"` + pkg + `"` + } else if len(parts) == 2 { + // Two parts - alias and package + alias := parts[0] + pkg := strings.Trim(parts[1], `"`) + return alias + ` "` + pkg + `"` + } + + // Fallback - return as is with quotes if not already quoted + if !strings.HasPrefix(spec, `"`) { + return `"` + spec + `"` + } + return spec +} + +var filterCommentRegex = regexp.MustCompile(`//\s*db:filter\s+(.+)`) + +// parseStruct extracts filter field information from a struct +func (g *Generator) parseStruct(name string, structType *ast.StructType) StructInfo { + info := StructInfo{ + Name: name, + ReceiverName: strings.ToLower(name[:1]), + Fields: []FilterField{}, + } + + for _, field := range structType.Fields.List { + if field.Doc == nil { + continue + } + + // Check for db:filter comment + var column string + for _, comment := range field.Doc.List { + matches := filterCommentRegex.FindStringSubmatch(comment.Text) + if len(matches) > 1 { + column = strings.TrimSpace(matches[1]) + break + } + } + + if column == "" { + continue + } + + // Extract field information + for _, fieldName := range field.Names { + fieldType := g.getTypeString(field.Type) + + // Extract query tag value if available + queryParam := fieldName.Name // Default to field name + if field.Tag != nil { + if tag := field.Tag.Value; tag != "" { + // Parse the struct tag to extract the "query" tag value + if queryValue := g.extractQueryTag(tag); queryValue != "" { + queryParam = queryValue + } + } + } + + info.Fields = append(info.Fields, FilterField{ + Name: fieldName.Name, + Column: column, + Type: fieldType, + QueryParam: queryParam, + }) + } + } + + return info +} + +// getTypeString converts an ast.Expr to a string representation +func (g *Generator) getTypeString(expr ast.Expr) string { + switch t := expr.(type) { + case *ast.Ident: + return t.Name + case *ast.StarExpr: + return "*" + g.getTypeString(t.X) + case *ast.SelectorExpr: + return g.getTypeString(t.X) + "." + t.Sel.Name + case *ast.ArrayType: + return "[]" + g.getTypeString(t.Elt) + default: + return "interface{}" + } +} + +// extractQueryTag extracts the value from the "query" struct tag +func (g *Generator) extractQueryTag(tag string) string { + // Remove backticks from the tag + tag = strings.Trim(tag, "`") + + // Look for query:"value" pattern + queryRegex := regexp.MustCompile(`query:"([^"]*)"`) + matches := queryRegex.FindStringSubmatch(tag) + if len(matches) > 1 { + return matches[1] + } + + return "" +} + +// generateCode generates the filter methods code +func (g *Generator) generateCode(structs []StructInfo, outputFile string) error { + // Skip file creation if no structs + if len(structs) == 0 { + return nil + } + + // Only generate for supported filter types + if g.filterType != "bob" && g.filterType != "boiler" { + return nil + } + + tmpl := template.Must(template.New("filters").Parse(codeTemplate)) + + file, err := os.Create(outputFile) + if err != nil { + return fmt.Errorf("failed to create output file: %w", err) + } + defer file.Close() + + // Collect all unique additional imports from all structs + importSet := make(map[string]bool) + for _, s := range structs { + for _, imp := range s.Imports { + importSet[imp] = true + } + } + + var additionalImports []string + for imp := range importSet { + additionalImports = append(additionalImports, imp) + } + + // Check if any struct has sorting + hasSortingStructs := false + for _, s := range structs { + if s.SortField != "" { + hasSortingStructs = true + break + } + } + + data := struct { + FilterType string + Package string + Structs []StructInfo + AdditionalImports []string + HasSortingStructs bool + }{ + FilterType: g.filterType, + Package: g.packageName, + Structs: structs, + AdditionalImports: additionalImports, + HasSortingStructs: hasSortingStructs, + } + + return tmpl.Execute(file, data) +} + +const codeTemplate = `// Code generated by go-libs/v2/dbutils/ops/gen/cmd. DO NOT EDIT. + +package {{.Package}} + +import ( + "github.com/top-solution/go-libs/v2/dbutils/ops" + {{if eq .FilterType "bob"}}"github.com/stephenafamo/bob" + "github.com/stephenafamo/bob/dialect/psql/dialect" + "github.com/top-solution/go-libs/v2/dbutils/ops/bobops" + {{ else if eq .FilterType "boiler"}}"github.com/top-solution/go-libs/v2/dbutils/ops/boilerops" + "github.com/volatiletech/sqlboiler/v4/queries/qm" + {{end}} +{{range .AdditionalImports}} {{.}} +{{end}}){{$lib := .FilterType}} +{{range .Structs}}{{$receiver := .ReceiverName}}{{$structName := .Name}} +// {{.Name}}ColumnsMap is a FilterMap mapping filter names to DB columns +// DO NOT EDIT: This var is generated by go-libs/v2/dbutils/ops/gen/cmd +var {{.Name}}ColumnsMap = {{if eq $lib "bob"}}bobops.NewBobFilterMap{{else if eq $lib "boiler"}}boilerops.NewBoilerFilterMap{{end}}(map[string]string{ + {{range .Fields}}"{{.QueryParam}}": {{.Column}},{{end}} +}) +// AddFilters adds database filters based on the struct fields with db:filter comments +// DO NOT EDIT: This func is generated by go-libs/v2/dbutils/ops/gen/cmd +func ({{.ReceiverName}} *{{.Name}}) AddFilters(q {{if eq $lib "bob"}}*[]bob.Mod[*dialect.SelectQuery]{{else if eq $lib "boiler"}}*[]qm.QueryMod{{end}}) error { + {{if eq $lib "bob"}}var qmods []bob.Mod[*dialect.SelectQuery] + {{else if eq $lib "boiler"}} + var qmods []qm.QueryMod{{end}} + {{range .Fields}}{{if eq .Type "string"}}if {{$receiver}}.{{.Name}} != "" { + op, cond, rawValue, err := ops.CurrentWhereFilters().Parse({{$receiver}}.{{.Name}}) + if err != nil { + return err + } + + qmod, _, _, err := {{$structName}}ColumnsMap.Filterer.ParseFilter(cond, {{.Column}}, op, rawValue, false) + if err != nil { + return err + } + qmods = append(qmods, qmod) + }{{else if eq .Type "*string"}} + if {{$receiver}}.{{.Name}} != nil && *{{$receiver}}.{{.Name}} != "" { + op, cond, rawValue, err := ops.CurrentWhereFilters().Parse(*{{$receiver}}.{{.Name}}) + if err != nil { + return err + } + + qmod, _, _, err := {{$structName}}ColumnsMap.Filterer.ParseFilter(cond, {{.Column}}, op, rawValue, false) + if err != nil { + return err + } + qmods = append(qmods, qmod) + }{{else if eq .Type "[]string"}} + if len({{$receiver}}.{{.Name}}) > 0 { + for _, v := range {{$receiver}}.{{.Name}} { + op, cond, rawValue, err := ops.CurrentWhereFilters().Parse(v) + if err != nil { + return err + } + + qmod, _, _, err := {{$structName}}ColumnsMap.Filterer.ParseFilter(cond, {{.Column}}, op, rawValue, false) + if err != nil { + return err + } + qmods = append(qmods, qmod) + } + } + {{else}} + // TODO: Add support for {{.Type}} type for field {{.Name}} + {{end}} + {{end}} + + *q = append(*q, qmods...) + + return nil +} +{{if ne .SortField ""}} +// AddSorting adds the result of ParseSorting to a given query +func ({{.ReceiverName}} *{{.Name}}) AddSorting(query {{if eq $lib "bob"}}*[]bob.Mod[*dialect.SelectQuery]{{else if eq $lib "boiler"}}*[]qm.QueryMod{{end}}) error { + return {{$structName}}ColumnsMap.AddSorting(query, {{$receiver}}.{{.SortField}}) +} +{{end}} +{{end}} +` diff --git a/v2/dbutils/ops/gen/generator_test.go b/v2/dbutils/ops/gen/generator_test.go new file mode 100644 index 0000000..895575a --- /dev/null +++ b/v2/dbutils/ops/gen/generator_test.go @@ -0,0 +1,667 @@ +package gen + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGenerator(t *testing.T) { + // Create a temporary test file + testContent := `package testpkg + +// db:filter +type ListDCRsRequest struct { + // db:filter bob_gen.ColumnNames.DCRS.Type + Type string ` + "`query:\"type\"`" + ` + // db:filter bob_gen.ColumnNames.DCRS.Status + Status string ` + "`query:\"status\"`" + ` + // Regular field without filter comment + Limit int ` + "`query:\"limit\"`" + ` +} + +// db:filter +type AnotherRequest struct { + // db:filter bob_gen.ColumnNames.Users.Name + Name *string ` + "`query:\"name\"`" + ` + // db:filter bob_gen.ColumnNames.Users.Tags + Tags []string ` + "`query:\"tags\"`" + ` +} + +type NoFilterRequest struct { + // db:filter bob_gen.ColumnNames.NoFilter.Field + Field string ` + "`query:\"field\"`" + ` +}` + + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.go") + err := os.WriteFile(testFile, []byte(testContent), 0644) + require.NoError(t, err) + + generator := NewGenerator("testpkg", tmpDir, "bob") + structs, err := generator.parseFile(testFile) + require.NoError(t, err) + + // Should find 2 structs with filter fields + assert.Len(t, structs, 2) + + // Check first struct + dcrsStruct := structs[0] + assert.Equal(t, "ListDCRsRequest", dcrsStruct.Name) + assert.Equal(t, "testpkg", dcrsStruct.Package) + assert.Equal(t, "l", dcrsStruct.ReceiverName) + assert.Len(t, dcrsStruct.Fields, 2) + + // Check fields + typeField := dcrsStruct.Fields[0] + assert.Equal(t, "Type", typeField.Name) + assert.Equal(t, "bob_gen.ColumnNames.DCRS.Type", typeField.Column) + assert.Equal(t, "string", typeField.Type) + + statusField := dcrsStruct.Fields[1] + assert.Equal(t, "Status", statusField.Name) + assert.Equal(t, "bob_gen.ColumnNames.DCRS.Status", statusField.Column) + assert.Equal(t, "string", statusField.Type) + + // Check second struct + usersStruct := structs[1] + assert.Equal(t, "AnotherRequest", usersStruct.Name) + assert.Equal(t, "a", usersStruct.ReceiverName) + assert.Len(t, usersStruct.Fields, 2) + + nameField := usersStruct.Fields[0] + assert.Equal(t, "Name", nameField.Name) + assert.Equal(t, "bob_gen.ColumnNames.Users.Name", nameField.Column) + assert.Equal(t, "*string", nameField.Type) + + tagsField := usersStruct.Fields[1] + assert.Equal(t, "Tags", tagsField.Name) + assert.Equal(t, "bob_gen.ColumnNames.Users.Tags", tagsField.Column) + assert.Equal(t, "[]string", tagsField.Type) +} + +func TestGenerator_GetTypeString(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "simple type", + input: "string", + expected: "string", + }, + { + name: "pointer type", + input: "*string", + expected: "*string", + }, + { + name: "slice type", + input: "[]string", + expected: "[]string", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // This test is more conceptual since getTypeString works with AST nodes + // The actual type parsing is tested through the integration tests + assert.NotEmpty(t, tt.expected) + }) + } +} + +func TestGenerator_NoFilterStructs(t *testing.T) { + testContent := `package testpkg + +type SimpleRequest struct { + Name string ` + "`query:\"name\"`" + ` + Age int ` + "`query:\"age\"`" + ` +}` + + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.go") + err := os.WriteFile(testFile, []byte(testContent), 0644) + require.NoError(t, err) + + generator := NewGenerator("testpkg", tmpDir, "bob") + + err = generator.GenerateFromFile(testFile) + require.NoError(t, err) + + // Output file should not be created when no filter structs are found + outputFile := filepath.Join(tmpDir, "test_filters.gen.go") + _, err = os.Stat(outputFile) + assert.True(t, os.IsNotExist(err)) +} + +func TestGenerator_GenerateFromPackage(t *testing.T) { + tmpDir := t.TempDir() + + // Create multiple Go files with filter structs + file1Content := `package requests + +// db:filter +type ListUsersRequest struct { + // db:filter bob_gen.ColumnNames.Users.Name + Name string ` + "`query:\"name\"`" + ` +}` + + file2Content := `package requests + +// db:filter +type ListOrdersRequest struct { + // db:filter bob_gen.ColumnNames.Orders.Status + Status string ` + "`query:\"status\"`" + ` +}` + + // Create a file without filter structs + file3Content := `package requests + +type SimpleRequest struct { + Field string ` + "`query:\"field\"`" + ` +}` + + // Create a test file (should be ignored) + testFileContent := `package requests + +// db:filter +type TestStruct struct { + // db:filter bob_gen.ColumnNames.Test.Field + Field string ` + "`query:\"field\"`" + ` +}` + + // Create a generated file (should be ignored) + genFileContent := `package requests + +// db:filter +type GenStruct struct { + // db:filter bob_gen.ColumnNames.Gen.Field + Field string ` + "`query:\"field\"`" + ` +}` + + // Write all files + err := os.WriteFile(filepath.Join(tmpDir, "users.go"), []byte(file1Content), 0644) + require.NoError(t, err) + + err = os.WriteFile(filepath.Join(tmpDir, "orders.go"), []byte(file2Content), 0644) + require.NoError(t, err) + + err = os.WriteFile(filepath.Join(tmpDir, "simple.go"), []byte(file3Content), 0644) + require.NoError(t, err) + + err = os.WriteFile(filepath.Join(tmpDir, "test_test.go"), []byte(testFileContent), 0644) + require.NoError(t, err) + + err = os.WriteFile(filepath.Join(tmpDir, "existing_gen.go"), []byte(genFileContent), 0644) + require.NoError(t, err) + + // Generate filters for the entire package + generator := NewGenerator("requests", tmpDir, "bob") + err = generator.GenerateFromPackage() + require.NoError(t, err) + + // Check that filter files were created for files with filter structs + usersFilterFile := filepath.Join(tmpDir, "users_filters.gen.go") + _, err = os.Stat(usersFilterFile) + require.NoError(t, err) + + ordersFilterFile := filepath.Join(tmpDir, "orders_filters.gen.go") + _, err = os.Stat(ordersFilterFile) + require.NoError(t, err) + + // Check that no filter file was created for simple.go (no filter structs) + simpleFilterFile := filepath.Join(tmpDir, "simple_filters.gen.go") + _, err = os.Stat(simpleFilterFile) + assert.True(t, os.IsNotExist(err)) + + // Check that no filter file was created for test file + testFilterFile := filepath.Join(tmpDir, "test_test_filters.gen.go") + _, err = os.Stat(testFilterFile) + assert.True(t, os.IsNotExist(err)) + + // Check that no filter file was created for existing generated file + existingGenFilterFile := filepath.Join(tmpDir, "existing_gen_filters.gen.go") + _, err = os.Stat(existingGenFilterFile) + assert.True(t, os.IsNotExist(err)) + + // Verify content of generated files + usersGenerated, err := os.ReadFile(usersFilterFile) + require.NoError(t, err) + assert.Contains(t, string(usersGenerated), "func (l *ListUsersRequest) AddFilters") + + ordersGenerated, err := os.ReadFile(ordersFilterFile) + require.NoError(t, err) + assert.Contains(t, string(ordersGenerated), "func (l *ListOrdersRequest) AddFilters") +} + +func TestGenerator_GetOutputFilename(t *testing.T) { + generator := NewGenerator("test", ".", "bob") + + tests := []struct { + name string + input string + expected string + }{ + { + name: "simple file", + input: "requests.go", + expected: "requests_filters.gen.go", + }, + { + name: "file with path", + input: "/path/to/requests.go", + expected: "/path/to/requests_filters.gen.go", + }, + { + name: "file with multiple dots", + input: "my.requests.go", + expected: "my.requests_filters.gen.go", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := generator.getOutputFilename(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestGenerator_FilterCommentParsing(t *testing.T) { + testContent := `package testpkg + +// db:filter +type TestRequest struct { + // db:filter simple_column + Field1 string + // db:filter spaced_column + Field2 string + //db:filter no_space_column + Field3 string + // db:filter "quoted_column" + Field4 string + // Some other comment + Field5 string + // db:filter complex.column.name + Field6 string +}` + + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.go") + err := os.WriteFile(testFile, []byte(testContent), 0644) + require.NoError(t, err) + + generator := NewGenerator("testpkg", tmpDir, "bob") + structs, err := generator.parseFile(testFile) + require.NoError(t, err) + + require.Len(t, structs, 1) + testStruct := structs[0] + + // Should have 5 fields with filter comments (Field5 doesn't have db:filter) + assert.Len(t, testStruct.Fields, 5) + + expectedColumns := []string{ + "simple_column", + "spaced_column", + "no_space_column", + "\"quoted_column\"", + "complex.column.name", + } + + for i, field := range testStruct.Fields { + assert.Equal(t, expectedColumns[i], field.Column) + assert.Equal(t, "string", field.Type) + } +} + +func TestGenerator_StructFilterComment(t *testing.T) { + tests := []struct { + name string + content string + expectFound bool + }{ + { + name: "has db:filter comment", + content: `package testpkg + +// db:filter +type TestRequest struct { + // db:filter column_name + Field string +}`, + expectFound: true, + }, + { + name: "no struct filter comment", + content: `package testpkg + +type TestRequest struct { + // db:filter column_name + Field string +}`, + expectFound: false, + }, + { + name: "other comment", + content: `package testpkg + +// Some other comment +type TestRequest struct { + // db:filter column_name + Field string +}`, + expectFound: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.go") + err := os.WriteFile(testFile, []byte(tt.content), 0644) + require.NoError(t, err) + + generator := NewGenerator("testpkg", tmpDir, "bob") + structs, err := generator.parseFile(testFile) + require.NoError(t, err) + + if tt.expectFound { + require.Len(t, structs, 1) + } else { + assert.Len(t, structs, 0) + } + }) + } +} +func TestGenerator_ImportComments(t *testing.T) { + testContent := `package testpkg + +// db:filter +// db:filter import "fmt" +// db:filter import "time" +// db:filter import github.com/example/pkg +// db:filter import json "encoding/json" +// db:filter import ctx "context" +type TestRequest struct { + // db:filter column_name + Field string +}` + + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.go") + err := os.WriteFile(testFile, []byte(testContent), 0644) + require.NoError(t, err) + + generator := NewGenerator("testpkg", tmpDir, "bob") + structs, err := generator.parseFile(testFile) + require.NoError(t, err) + + require.Len(t, structs, 1) + testStruct := structs[0] + + // Should have 5 imports + assert.Len(t, testStruct.Imports, 5) + assert.Contains(t, testStruct.Imports, `"fmt"`) + assert.Contains(t, testStruct.Imports, `"time"`) + assert.Contains(t, testStruct.Imports, `"github.com/example/pkg"`) + assert.Contains(t, testStruct.Imports, `json "encoding/json"`) + assert.Contains(t, testStruct.Imports, `ctx "context"`) +} + +func TestGenerator_GenerateWithImports(t *testing.T) { + tmpDir := t.TempDir() + + // Create test input file with imports + testContent := `package requests + +// db:filter +// db:filter import "fmt" +// db:filter import "encoding/json" +type ListUsersRequest struct { + // db:filter bob_gen.ColumnNames.Users.Name + Name string ` + "`query:\"name\"`" + ` +}` + + inputFile := filepath.Join(tmpDir, "requests.go") + err := os.WriteFile(inputFile, []byte(testContent), 0644) + require.NoError(t, err) + + generator := NewGenerator("requests", tmpDir, "bob") + + err = generator.GenerateFromFile(inputFile) + require.NoError(t, err) + + // Check that output file was created + outputFile := filepath.Join(tmpDir, "requests_filters.gen.go") + _, err = os.Stat(outputFile) + require.NoError(t, err) + + // Read and verify generated content + generated, err := os.ReadFile(outputFile) + require.NoError(t, err) + + generatedStr := string(generated) + + // Check for expected imports + assert.Contains(t, generatedStr, `"fmt"`) + assert.Contains(t, generatedStr, `"encoding/json"`) + assert.Contains(t, generatedStr, "package requests") + assert.Contains(t, generatedStr, "func (l *ListUsersRequest) AddFilters") +} +func TestGenerator_ParseImportSpec(t *testing.T) { + generator := NewGenerator("test", ".", "bob") + + tests := []struct { + name string + input string + expected string + }{ + { + name: "quoted package", + input: `"fmt"`, + expected: `"fmt"`, + }, + { + name: "unquoted package", + input: `fmt`, + expected: `"fmt"`, + }, + { + name: "alias with quoted package", + input: `json "encoding/json"`, + expected: `json "encoding/json"`, + }, + { + name: "alias with unquoted package", + input: `ctx context`, + expected: `ctx "context"`, + }, + { + name: "complex package path", + input: `github.com/example/pkg`, + expected: `"github.com/example/pkg"`, + }, + { + name: "alias with complex package path", + input: `pkg github.com/example/pkg`, + expected: `pkg "github.com/example/pkg"`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := generator.parseImportSpec(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} +func TestGenerator_SortField(t *testing.T) { + tests := []struct { + name string + content string + expectSort bool + sortField string + }{ + { + name: "struct with sortField comment", + content: `package testpkg + +// db:filter +// db:filter sortField Sort +type TestRequest struct { + // db:filter column_name + Field string + Sort []string ` + "`query:\"sort\"`" + ` +}`, + expectSort: true, + sortField: "Sort", + }, + { + name: "struct without sortField comment", + content: `package testpkg + +// db:filter +type TestRequest struct { + // db:filter column_name + Field string + Sort []string ` + "`query:\"sort\"`" + ` +}`, + expectSort: false, + sortField: "", + }, + { + name: "struct with different sortField name", + content: `package testpkg + +// db:filter +// db:filter sortField OrderBy +type TestRequest struct { + // db:filter column_name + Field string + OrderBy []string ` + "`query:\"order_by\"`" + ` +}`, + expectSort: true, + sortField: "OrderBy", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.go") + err := os.WriteFile(testFile, []byte(tt.content), 0644) + require.NoError(t, err) + + generator := NewGenerator("testpkg", tmpDir, "bob") + structs, err := generator.parseFile(testFile) + require.NoError(t, err) + + require.Len(t, structs, 1) + if tt.expectSort { + assert.Equal(t, tt.sortField, structs[0].SortField) + assert.NotEmpty(t, structs[0].SortField) + } else { + assert.Empty(t, structs[0].SortField) + } + }) + } +} + +func TestGenerator_GenerateWithSorting(t *testing.T) { + tmpDir := t.TempDir() + + // Create test input file with sortField comment + testContent := `package requests + +// db:filter +// db:filter sortField Sort +type ListUsersRequest struct { + // db:filter bob_gen.ColumnNames.Users.Name + Name string ` + "`query:\"name\"`" + ` + Sort []string ` + "`query:\"sort\"`" + ` +}` + + inputFile := filepath.Join(tmpDir, "requests.go") + err := os.WriteFile(inputFile, []byte(testContent), 0644) + require.NoError(t, err) + + generator := NewGenerator("requests", tmpDir, "bob") + + err = generator.GenerateFromFile(inputFile) + require.NoError(t, err) + + // Check that output file was created + outputFile := filepath.Join(tmpDir, "requests_filters.gen.go") + _, err = os.Stat(outputFile) + require.NoError(t, err) + + // Read and verify generated content + generated, err := os.ReadFile(outputFile) + require.NoError(t, err) + + generatedStr := string(generated) + + // Check for expected content + assert.Contains(t, generatedStr, "package requests") + assert.Contains(t, generatedStr, "func (l *ListUsersRequest) AddFilters") + assert.Contains(t, generatedStr, "func (l *ListUsersRequest) AddSorting") + assert.Contains(t, generatedStr, "ListUsersRequestColumnsMap.AddSorting(query, l.Sort)") + + // Check for proper imports structure + assert.Contains(t, generatedStr, `"github.com/top-solution/go-libs/v2/dbutils/ops"`) + assert.Contains(t, generatedStr, `"github.com/stephenafamo/bob"`) + assert.Contains(t, generatedStr, `"github.com/stephenafamo/bob/dialect/psql/dialect"`) + assert.Contains(t, generatedStr, `"github.com/top-solution/go-libs/v2/dbutils/ops/bobops"`) +} + +func TestGenerator_GenerateWithoutSorting(t *testing.T) { + tmpDir := t.TempDir() + + // Create test input file without sortField comment + testContent := `package requests + +// db:filter +type ListUsersRequest struct { + // db:filter bob_gen.ColumnNames.Users.Name + Name string ` + "`query:\"name\"`" + ` +}` + + inputFile := filepath.Join(tmpDir, "requests.go") + err := os.WriteFile(inputFile, []byte(testContent), 0644) + require.NoError(t, err) + + generator := NewGenerator("requests", tmpDir, "bob") + + err = generator.GenerateFromFile(inputFile) + require.NoError(t, err) + + // Check that output file was created + outputFile := filepath.Join(tmpDir, "requests_filters.gen.go") + _, err = os.Stat(outputFile) + require.NoError(t, err) + + // Read and verify generated content + generated, err := os.ReadFile(outputFile) + require.NoError(t, err) + + generatedStr := string(generated) + + // Check for expected content + assert.Contains(t, generatedStr, "package requests") + assert.Contains(t, generatedStr, "func (l *ListUsersRequest) AddFilters") + // Should NOT contain AddSorting function + assert.NotContains(t, generatedStr, "func (l *ListUsersRequest) AddSorting") + assert.NotContains(t, generatedStr, "AddSorting") + + // Check for proper imports structure (without errors) + assert.Contains(t, generatedStr, `"github.com/top-solution/go-libs/v2/dbutils/ops"`) + assert.Contains(t, generatedStr, `"github.com/stephenafamo/bob"`) + assert.Contains(t, generatedStr, `"github.com/stephenafamo/bob/dialect/psql/dialect"`) + assert.Contains(t, generatedStr, `"github.com/top-solution/go-libs/v2/dbutils/ops/bobops"`) +} diff --git a/v2/dbutils/ops/gen/tst/tst.go b/v2/dbutils/ops/gen/tst/tst.go new file mode 100644 index 0000000..55e6aec --- /dev/null +++ b/v2/dbutils/ops/gen/tst/tst.go @@ -0,0 +1,19 @@ +//go:generate go run ../cmd/main.go bob . +package tst + +type Sortable struct { + Sort []string +} + +// db:filter +// db:filter import "fmt" +// db:filter sortField Sort +type TestStruct struct { + Sortable + // db:filter "stuff" + Test string `query:"test"` + // db:filter fmt.Sprintf("heee") + Test2 *string `query:"test2"` + // db:filter "EEEI" + Test3 []string `query:"test3"` +} diff --git a/v2/dbutils/ops/gen/tst/tst_filters.gen.go b/v2/dbutils/ops/gen/tst/tst_filters.gen.go new file mode 100644 index 0000000..b387d06 --- /dev/null +++ b/v2/dbutils/ops/gen/tst/tst_filters.gen.go @@ -0,0 +1,77 @@ +// Code generated by go-libs/v2/dbutils/ops/gen/cmd. DO NOT EDIT. + +package tst + +import ( + "github.com/top-solution/go-libs/v2/dbutils/ops" + "github.com/stephenafamo/bob" + "github.com/stephenafamo/bob/dialect/psql/dialect" + "github.com/top-solution/go-libs/v2/dbutils/ops/bobops" + + "fmt" +) + +// TestStructColumnsMap is a FilterMap mapping filter names to DB columns +// DO NOT EDIT: This file is generated by go-libs/v2/dbutils/ops/gen/cmd +var TestStructColumnsMap = bobops.NewBobFilterMap(map[string]string{ + "test": "stuff","test2": fmt.Sprintf("heee"),"test3": "EEEI", +}) +// AddFilters adds database filters based on the struct fields with db:filter comments +// DO NOT EDIT: This file is generated by go-libs/v2/dbutils/ops/gen/cmd +func (t *TestStruct) AddFilters(q *[]bob.Mod[*dialect.SelectQuery]) error { + var qmods []bob.Mod[*dialect.SelectQuery] + + if t.Test != "" { + op, cond, rawValue, err := ops.CurrentWhereFilters().Parse(t.Test) + if err != nil { + return err + } + + qmod, _, _, err := TestStructColumnsMap.Filterer.ParseFilter(cond, "stuff", op, rawValue, false) + if err != nil { + return err + } + qmods = append(qmods, qmod) + } + + if t.Test2 != nil && *t.Test2 != "" { + op, cond, rawValue, err := ops.CurrentWhereFilters().Parse(*t.Test2) + if err != nil { + return err + } + + qmod, _, _, err := TestStructColumnsMap.Filterer.ParseFilter(cond, fmt.Sprintf("heee"), op, rawValue, false) + if err != nil { + return err + } + qmods = append(qmods, qmod) + } + + if len(t.Test3) > 0 { + for _, v := range t.Test3 { + op, cond, rawValue, err := ops.CurrentWhereFilters().Parse(v) + if err != nil { + return err + } + + qmod, _, _, err := TestStructColumnsMap.Filterer.ParseFilter(cond, "EEEI", op, rawValue, false) + if err != nil { + return err + } + qmods = append(qmods, qmod) + } + } + + + + *q = append(*q, qmods...) + + return nil +} + +// AddSorting adds the result of ParseSorting to a given query +func (t *TestStruct) AddSorting(query *[]bob.Mod[*dialect.SelectQuery]) error { + return TestStructColumnsMap.AddSorting(query, t.Sort) +} + + diff --git a/v2/dbutils/ops/ops.go b/v2/dbutils/ops/ops.go index b379e3d..9aab55a 100644 --- a/v2/dbutils/ops/ops.go +++ b/v2/dbutils/ops/ops.go @@ -30,7 +30,7 @@ type Filterer[T any] interface { // FilterMap is a helper struct to parse filters into a slice of query mods // Query Mods can be from different query builders type FilterMap[T any] struct { - filterer Filterer[T] + Filterer Filterer[T] fields map[string]string } @@ -39,14 +39,14 @@ type FilterMap[T any] struct { // If you need to use this with bob, see bobops package func NewFilterMap[T any](fields map[string]string, f Filterer[T]) FilterMap[T] { return FilterMap[T]{ - filterer: f, + Filterer: f, fields: fields, } } // AddFilters parses the filters and adds them to the given list of query mods func (f FilterMap[T]) AddFilters(q *[]T, attribute string, filters ...string) error { - filter, _, _, _, err := parseFilters(f.filterer, f.fields, attribute, false, filters...) + filter, _, _, _, err := parseFilters(f.Filterer, f.fields, attribute, false, filters...) if err != nil { return fmt.Errorf("error parsing filters: %w", err) } @@ -66,7 +66,7 @@ func (f FilterMap[T]) AddHavingFilters(query *[]T, attribute string, data ...str // ParseFilters parses the filters and returns the query mods, raw queries, operators and values func (f FilterMap[T]) ParseFilters(attribute string, having bool, filters ...string) ([]T, []string, []string, []interface{}, error) { - return parseFilters(f.filterer, f.fields, attribute, having, filters...) + return parseFilters(f.Filterer, f.fields, attribute, having, filters...) } // ParseSorting generates an OrderBy QueryMod starting from a given list of user-inputted values and an attribute->column map @@ -87,7 +87,7 @@ func (f FilterMap[T]) ParseSorting(sort []string) (T, error) { } sortList = append(sortList, f.fields[elem]+direction) } - return f.filterer.ParseSorting(sortList) + return f.Filterer.ParseSorting(sortList) } // AddSorting adds the result of ParseSorting to a given query @@ -104,36 +104,22 @@ func (f FilterMap[T]) AddSorting(query *[]T, sort []string) (err error) { return nil } -func parseFilters[T any](filterer Filterer[T], f map[string]string, attribute string, having bool, filters ...string) ([]T, []string, []string, []interface{}, error) { +func parseFilters[T any](filterer Filterer[T], f map[string]string, attribute string, having bool, filters ...string) ([]T, []string, []string, []any, error) { var qmods []T var rawQueries []string var ops []string - var vals []interface{} + var vals []any if _, ok := f[attribute]; !ok { return nil, nil, nil, nil, fmt.Errorf("attribute %s not found", attribute) } - driverFilters := postgresWhereFilters - if dbutils.CurrentDriver == dbutils.MSSQLDriver { - driverFilters = msSQLWhereFilters - } - for _, filter := range filters { - spl := strings.SplitN(filter, ":", 2) - op := spl[0] - rawValue := "" - if len(spl) < 2 { - if !IsUnaryOp(op) { - return nil, nil, nil, nil, fmt.Errorf("operation %s is not valid", op) - } - } else { - rawValue = spl[1] - } - if _, ok := driverFilters[op]; !ok { - return nil, nil, nil, nil, fmt.Errorf("operation %s is not implemented", op) + op, cond, rawValue, err := CurrentWhereFilters().Parse(filter) + if err != nil { + return nil, nil, nil, nil, err } - qmod, raw, val, err := filterer.ParseFilter(driverFilters[op], f[attribute], op, rawValue, having) + qmod, raw, val, err := filterer.ParseFilter(cond, f[attribute], op, rawValue, having) if err != nil { return nil, nil, nil, nil, err } @@ -145,7 +131,28 @@ func parseFilters[T any](filterer Filterer[T], f map[string]string, attribute st return qmods, rawQueries, ops, vals, nil } -var msSQLWhereFilters = map[string]string{ +type WhereFilters map[string]string + +func (w WhereFilters) Parse(filter string) (op string, cond string, val string, err error) { + spl := strings.SplitN(filter, ":", 2) + op = spl[0] + rawValue := "" + if len(spl) < 2 { + if !IsUnaryOp(op) { + return "", "", "", fmt.Errorf("operation %s is not valid", op) + } + } else { + rawValue = spl[1] + } + if _, ok := w[op]; !ok { + return "", "", "", fmt.Errorf("operation %s is not implemented", op) + } + + return op, w[op], rawValue, nil + +} + +var msSQLWhereFilters = WhereFilters{ "eq": "{} = ?", "neq": "{} != ?", "like": "{} LIKE ? ESCAPE '_'", @@ -162,7 +169,7 @@ var msSQLWhereFilters = map[string]string{ "isNotEmpty": "coalesce({},'') != ''", } -var postgresWhereFilters = map[string]string{ +var postgresWhereFilters = WhereFilters{ "eq": "{} = ?", "neq": "{} != ?", "like": "{} ILIKE ? ESCAPE '_'", @@ -178,3 +185,11 @@ var postgresWhereFilters = map[string]string{ "isEmpty": "coalesce({},'') = ''", "isNotEmpty": "coalesce({},'') != ''", } + +func CurrentWhereFilters() WhereFilters { + // FIXME: This can't work if using two connections with different drivers + if dbutils.CurrentDriver == dbutils.MSSQLDriver { + return msSQLWhereFilters + } + return postgresWhereFilters +}