diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 3b99c17..fae1fc8 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -2,23 +2,23 @@ name: Test on: push: - branches: [ main ] + branches: [main] pull_request: - branches: [ main ] + branches: [main] jobs: test: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Set up Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v5 with: - go-version: "1.20" + go-version: "1.22" - name: Go linter - uses: golangci/golangci-lint-action@v3 + uses: golangci/golangci-lint-action@v8 with: - version: v1.53 + version: v2.1 - name: Run tests run: go test -v ./... diff --git a/.golangci.yml b/.golangci.yml index c42ddf9..c1e0a09 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,16 +1,10 @@ -run: - skip-files: - - _test.go +version: "2" + linters: - enable-all: true + default: all disable: - - maligned - - interfacer - - scopelint # deprecated - - golint # deprecated - dupl - funlen - - gomnd - lll - gochecknoglobals - varnamelen @@ -18,41 +12,43 @@ linters: - gomoddirectives - godox - gocyclo - - exhaustivestruct - exhaustruct - tagliatelle - wsl + # - wsl_v5 - forbidigo - makezero - depguard - - wrapcheck # will fix later + - wrapcheck - gocritic # does not like log.Fatal with defers - - gci # collides with other linters - godot # comments ending in periods... - cyclop # like gocyclo - gocognit # like cyclop - maintidx # another complexity one - - goerr113 # wants static errors for all non-formatted error returns - errname # base error naming - nilnil # inconsistent error returns - prealloc # wants to declaring vars with := struct{} - - ifshort # if blocks could be inlined - nlreturn # wants newline before return, break, and continue (dont want newline before break & continue) - exhaustive # some missing switch statements - - nestif # nested if blocks, complexity issues + - nestif # nested if blocks, complexity issues - forcetypeassert # missing some checks, lots of false errors - containedctx # some contexts embedded in structs - contextcheck # probably worth looking at - wastedassign # some interesting assignments, does not like declaring new vars like := "" - promlinter # some observability stuff - nonamedreturns - - nosnakecase # disbaled because of too many false positives due to protobuf types -linters-settings: - goheader: - values: - const: - COMPANY: Ctrl IQ, Inc - template: |- - SPDX-FileCopyrightText: Copyright (c) {{ YEAR-RANGE }}, {{ COMPANY }}. All rights reserved - SPDX-License-Identifier: Apache-2.0 + settings: + goheader: + values: + const: + COMPANY: CTRL IQ, Inc + template: |- + SPDX-FileCopyrightText: Copyright (c) {{ YEAR-RANGE }}, {{ COMPANY }}. All rights reserved + SPDX-License-Identifier: Apache-2.0 + + exclusions: + warn-unused: false + paths: + - _test.go + - testproto/* diff --git a/go.mod b/go.mod index 722d236..31c5626 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,8 @@ module go.ciq.dev/pika -go 1.20 +go 1.22 + +toolchain go1.24.4 require ( github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230305170008-8188dc5388df @@ -13,7 +15,7 @@ require ( github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.8.4 github.com/wk8/go-ordered-map/v2 v2.1.8 - google.golang.org/protobuf v1.31.0 + google.golang.org/protobuf v1.36.6 ) require ( diff --git a/go.sum b/go.sum index 1cfe37a..3cc5d72 100644 --- a/go.sum +++ b/go.sum @@ -13,9 +13,8 @@ github.com/gertd/go-pluralize v0.2.1 h1:M3uASbVjMnTsPb0PNqg+E/24Vwigyo/tvyMTtAlL github.com/gertd/go-pluralize v0.2.1/go.mod h1:rbYaKDbsXxmRfr8uygAEKhOWsjyrrqrkHVpZvoOp8zk= github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= +github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/iancoleman/strcase v0.3.0 h1:nTXanmYxhfFAMjZL34Ov6gkzEsSJZ5DbhxWjvSASxEI= github.com/iancoleman/strcase v0.3.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho= github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g= @@ -43,15 +42,14 @@ github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+x github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 h1:nIPpBwaJSVYIxUFsDv3M8ofmx9yWTog9BfvIu0q41lo= github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8/go.mod h1:HUYIGzjTL3rfEspMxjDjgmT5uz5wzYJKVo23qUhYTos= go.uber.org/goleak v1.1.12 h1:gZAh5/EyT/HQwlpkCy6wTpqfH9H8Lz8zbm3dZh+OyzA= +go.uber.org/goleak v1.1.12/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc h1:mCRnTeVUjcrhlRmO0VK8a6k6Rrf6TF9htwo2pJVSjIU= golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= -google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= +google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/parser/gen.go b/parser/gen.go index 1fe3572..397d763 100644 --- a/parser/gen.go +++ b/parser/gen.go @@ -1,6 +1,8 @@ -// SPDX-FileCopyrightText: Copyright (c) 2023-2024, Ctrl IQ, Inc. All rights reserved +// SPDX-FileCopyrightText: Copyright (c) 2023-2025, CTRL IQ, Inc. All rights reserved // SPDX-License-Identifier: Apache-2.0 -package parser +//go:generate antlr -Dlanguage=Go -o . Filter.g4 -//go:generate antlr4 -Dlanguage=Go -no-visitor -package parser Filter.g4 +// Package parser provides ANTLR-generated parsers for filter expressions. +// This package contains generated code for parsing AIP-160 compliant filter syntax. +package parser diff --git a/pika.go b/pika.go index a83b531..df198ce 100644 --- a/pika.go +++ b/pika.go @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: Copyright (c) 2023-2024, Ctrl IQ, Inc. All rights reserved +// SPDX-FileCopyrightText: Copyright (c) 2023-2025, CTRL IQ, Inc. All rights reserved // SPDX-License-Identifier: Apache-2.0 package pika @@ -10,14 +10,20 @@ import ( ) var ( - PikaMetadataTableName = "PikaTableName" + // PikaMetadataTableName is the field name used to specify custom table names in struct tags. + PikaMetadataTableName = "PikaTableName" + // PikaMetadataDefaultOrderBy is the field name used to specify default ordering in struct tags. PikaMetadataDefaultOrderBy = "PikaDefaultOrderBy" - PikaMetadataFields = []string{ + // PikaMetadataFields contains all available metadata field names for Pika configuration. + PikaMetadataFields = []string{ PikaMetadataTableName, PikaMetadataDefaultOrderBy, } ) +// Q creates a new QuerySet for the given type T using the provided database connection. +// It automatically detects the database type and returns the appropriate QuerySet implementation. +// Currently supports PostgreSQL connections. Panics if an unsupported database type is provided. func Q[T any](x any) QuerySet[T] { if sql, ok := x.(*PostgreSQL); ok { return PSQLQuery[T](sql) @@ -199,6 +205,8 @@ type QuerySet[T any] interface { Transaction(ctx context.Context) (QuerySet[T], error) } +// NewArgs creates a new ordered map for storing named query arguments. +// This is a convenience function for creating argument maps that can be passed to QuerySet.Args(). func NewArgs() *orderedmap.OrderedMap[string, any] { return orderedmap.New[string, any]() } diff --git a/pika_aip_filter.go b/pika_aip_filter.go index abb116e..50c4ff6 100644 --- a/pika_aip_filter.go +++ b/pika_aip_filter.go @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: Copyright (c) 2023-2024, Ctrl IQ, Inc. All rights reserved +// SPDX-FileCopyrightText: Copyright (c) 2023-2025, CTRL IQ, Inc. All rights reserved // SPDX-License-Identifier: Apache-2.0 package pika @@ -18,6 +18,21 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" ) +// Static errors for err113 compliance +var ( + ErrInvalidSuffix = errors.New("invalid suffix for identifier") + ErrIdentifierNotAcceptable = errors.New("identifier is not acceptable") + ErrNestedExpressionsNotSupported = errors.New("nested expressions are not supported") + ErrCannotCombineMultipleValues = errors.New("cannot combine multiple values in subexpression") + ErrUnknownAliasType = errors.New("unknown alias type") + ErrTypeNotAccepted = errors.New("type is not accepted for identifier") + ErrValueNotAccepted = errors.New("value is not accepted for identifier") + ErrUnexpectedIdentifier = errors.New("unexpected identifier") + ErrMissingOperator = errors.New("missing operator") + ErrMissingIdentifier = errors.New("missing identifier") + ErrIdentifierNotAllowed = errors.New("identifier is not allowed") +) + // The goal of this AIP Filter extension is to be able to parse // grammar like the one used in AIP-160. // This extension uses Antlr generated parser to parse the @@ -27,10 +42,15 @@ import ( // The filters on the QuerySet are applied in the order they are // given. +// AIPFilter provides AIP-160 compliant filtering capabilities for database queries. +// It uses ANTLR-generated parsers to parse filter expressions and converts them +// into QuerySet operations that can be applied to database queries. type AIPFilter[T any] struct { QuerySet[T] } +// AIPFilterIdentifier configures how a specific identifier (field) should be handled +// during AIP-160 filter parsing, including type validation, value aliases, and column mapping. type AIPFilterIdentifier struct { // Value aliases are used to map a value to a different value. // Mostly useful for enums, where for example the values @@ -60,6 +80,8 @@ type AIPFilterIdentifier struct { IsRepeated bool } +// AIPFilterOptions provides configuration for AIP-160 filter parsing, +// including identifier validation, type checking, and field mapping. type AIPFilterOptions struct { // Identifiers are additional configuration for specific identifiers. Identifiers map[string]AIPFilterIdentifier @@ -86,13 +108,15 @@ func (a AIPFilterOptions) verifyOrderBy(orderBy string) ([]string, error) { idents := strings.Split(fullIdentifier, " ") identifier := idents[0] sort := "asc" + if len(idents) > 1 { sort = strings.ToLower(idents[1]) // Check if the suffix is valid if sort != "asc" && sort != "desc" { - return nil, fmt.Errorf("invalid suffix %s for identifier %s", sort, identifier) + return nil, fmt.Errorf("%w: %s for identifier %s", ErrInvalidSuffix, sort, identifier) } } + prefix := "" if sort == "desc" { prefix = "-" @@ -100,7 +124,7 @@ func (a AIPFilterOptions) verifyOrderBy(orderBy string) ([]string, error) { // Verify that the identifier is acceptable if !contains(a.AcceptableIdentifiers, identifier) { - return nil, fmt.Errorf("identifier %s is not acceptable", identifier) + return nil, fmt.Errorf("%w: %s", ErrIdentifierNotAcceptable, identifier) } // Check if column name is defined @@ -116,6 +140,7 @@ func (a AIPFilterOptions) verifyOrderBy(orderBy string) ([]string, error) { return newOrderBy, nil } +// NewAIPFilter creates a new AIPFilter instance for the given type T. func NewAIPFilter[T any]() *AIPFilter[T] { return &AIPFilter[T]{} } @@ -144,6 +169,7 @@ func (a *AIPFilter[T]) aip160(b QuerySet[T], filter string, options AIPFilterOpt if options.Identifiers == nil { options.Identifiers = map[string]AIPFilterIdentifier{} } + for identifier, opts := range options.Identifiers { // Verify that accepted types is in antlrValues for _, acceptedType := range opts.AcceptedTypes { @@ -176,6 +202,7 @@ func (a *AIPFilter[T]) aip160(b QuerySet[T], filter string, options AIPFilterOpt args: orderedmap.New[string, any](), }, } + i := 0 for { activeState := states[i] @@ -190,6 +217,7 @@ func (a *AIPFilter[T]) aip160(b QuerySet[T], filter string, options AIPFilterOpt activeState.filterContent = append(activeState.filterContent, activeState.activeExpr) activeState.activeExpr = nil } + break } @@ -204,12 +232,12 @@ func (a *AIPFilter[T]) aip160(b QuerySet[T], filter string, options AIPFilterOpt // We currently don't support nested expressions like this // (a = 1 AND (b = 2 OR c = 3)) if !activeState.initParens && activeState.activeParens { - return nil, fmt.Errorf("nested expressions are not supported") + return nil, fmt.Errorf("%w", ErrNestedExpressionsNotSupported) } // Disallow combined expression for values if activeState.activeIdentifier != "" { - return nil, fmt.Errorf("cannot combine multiple values in subexpression") + return nil, fmt.Errorf("%w", ErrCannotCombineMultipleValues) } if activeState.activeParens && !activeState.initParens { @@ -228,32 +256,41 @@ func (a *AIPFilter[T]) aip160(b QuerySet[T], filter string, options AIPFilterOpt activeState.initParens = false activeState.forceNot = activeState.activeNot } + continue + // Right parenthesis. // Closing previous parenthesis. case parser.FilterLexerRPAREN: activeState.activeParens = false + // If whitespace, ignore. case parser.FilterLexerWHITESPACE: continue + // If OR, enable OR mode. case parser.FilterLexerOR: // If it's already active, we need to add a hint to force innerOr // If activeOr is true, we need to add a hint instead activeState.activeOperator += HintOr activeState.activeOr = true + continue + // If AND, disable OR mode. case parser.FilterLexerAND: // We support multiple hints, since AND is default, we just // need to force it if OR is activated activeState.activeOr = false activeState.activeOperator += HintAnd + continue + // If NOT, enable NOT mode. case parser.FilterLexerNOT, parser.FilterLexerMINUS: activeState.activeNot = true + continue } @@ -274,6 +311,7 @@ func (a *AIPFilter[T]) aip160(b QuerySet[T], filter string, options AIPFilterOpt // If operator, set the current operator. newVal := fmt.Sprintf("%s%s", x, activeState.activeOperator) activeState.activeOperator = newVal + continue } } else { @@ -408,7 +446,7 @@ func (a *AIPFilter[T]) aip160(b QuerySet[T], filter string, options AIPFilterOpt // If alias is a uint64, we need to set the type to uint64 activeState.activeValueType = parser.FilterLexerNUM_UINT } else { - return nil, fmt.Errorf("unknown alias type %T", alias) + return nil, fmt.Errorf("%w: %T", ErrUnknownAliasType, alias) } } @@ -423,7 +461,7 @@ func (a *AIPFilter[T]) aip160(b QuerySet[T], filter string, options AIPFilterOpt } } if !isOk { - return nil, fmt.Errorf("type %s is not accepted for identifier %s", lexer.SymbolicNames[activeState.activeValueType], activeState.activeIdentifier) + return nil, fmt.Errorf("%w: %s for identifier %s", ErrTypeNotAccepted, lexer.SymbolicNames[activeState.activeValueType], activeState.activeIdentifier) } } @@ -438,7 +476,7 @@ func (a *AIPFilter[T]) aip160(b QuerySet[T], filter string, options AIPFilterOpt } } if !isOk { - return nil, fmt.Errorf("value %v is not accepted for identifier %s", activeState.activeValue, activeState.activeIdentifier) + return nil, fmt.Errorf("%w: %v for identifier %s", ErrValueNotAccepted, activeState.activeValue, activeState.activeIdentifier) } } } @@ -451,22 +489,22 @@ func (a *AIPFilter[T]) aip160(b QuerySet[T], filter string, options AIPFilterOpt activeState.activeIdentifier = t.GetText() continue } - return nil, fmt.Errorf("unexpected identifier %s", t.GetText()) + return nil, fmt.Errorf("%w: %s", ErrUnexpectedIdentifier, t.GetText()) } if activeState.activeOperator != "" && activeState.activeIdentifier != "" && activeState.activeValue != nil { operator := activeState.activeOperator if operator == "" { - return nil, fmt.Errorf("missing operator") + return nil, fmt.Errorf("%w", ErrMissingOperator) } if activeState.activeIdentifier == "" { - return nil, fmt.Errorf("missing identifier") + return nil, fmt.Errorf("%w", ErrMissingIdentifier) } // Check if AcceptableIdentifiers are set, if so check if identifier is valid if len(options.AcceptableIdentifiers) > 0 { if !contains(options.AcceptableIdentifiers, activeState.activeIdentifier) { - return nil, fmt.Errorf("identifier %s is not allowed", activeState.activeIdentifier) + return nil, fmt.Errorf("%w: %s", ErrIdentifierNotAllowed, activeState.activeIdentifier) } } @@ -505,7 +543,7 @@ func (a *AIPFilter[T]) aip160(b QuerySet[T], filter string, options AIPFilterOpt suffix += strconv.Itoa(prefixCount) } argKey := fmt.Sprintf("%s%s", cleanKey(key), suffix) - value := fmt.Sprintf(":%s", argKey) + value := ":" + argKey // For AIP-160 purposes, if the operator has HintILike, then we need to // wrap the value in % to match wildcard diff --git a/pika_aip_filter_proto.go b/pika_aip_filter_proto.go index 2138c8f..6956f4e 100644 --- a/pika_aip_filter_proto.go +++ b/pika_aip_filter_proto.go @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: Copyright (c) 2023-2024, Ctrl IQ, Inc. All rights reserved +// SPDX-FileCopyrightText: Copyright (c) 2023-2025, CTRL IQ, Inc. All rights reserved // SPDX-License-Identifier: Apache-2.0 package pika @@ -34,6 +34,8 @@ var ( } ) +// ProtoReflectOptions configures how protobuf message reflection is performed +// for generating AIP filter options. type ProtoReflectOptions struct { // Exclude is a list of field names to exclude from the filter // Uses proto name always, not JSON name @@ -51,7 +53,7 @@ func protoReflect(m proto.Message, opts ProtoReflectOptions) AIPFilterOptions { } fields := m.ProtoReflect().Descriptor().Fields() - for i := 0; i < fields.Len(); i++ { + for i := range fields.Len() { // Get field from message fd := fields.Get(i) @@ -109,10 +111,10 @@ func protoReflect(m proto.Message, opts ProtoReflectOptions) AIPFilterOptions { } case protoreflect.EnumKind: // Add all enum values as aliases - for i := 0; i < fd.Enum().Values().Len(); i++ { + for i := range fd.Enum().Values().Len() { enumName := string(fd.Enum().Values().Get(i).Name()) enumReverseSplit := strings.Split(enumName, "_") - for i := 0; i < len(enumReverseSplit)/2; i++ { + for i := range len(enumReverseSplit) / 2 { j := len(enumReverseSplit) - i - 1 enumReverseSplit[i], enumReverseSplit[j] = enumReverseSplit[j], enumReverseSplit[i] } @@ -148,10 +150,16 @@ func protoReflect(m proto.Message, opts ProtoReflectOptions) AIPFilterOptions { return res } +// ProtoReflect generates AIP filter options from a protobuf message using default settings. +// It uses reflection to analyze the message structure and create appropriate filter identifiers +// for each field based on their protobuf types. func ProtoReflect(m proto.Message) AIPFilterOptions { return protoReflect(m, ProtoReflectOptions{}) } +// ProtoReflectWithOpts generates AIP filter options from a protobuf message using custom options. +// It allows for more control over the reflection process, including field exclusion and +// custom column name mapping. func ProtoReflectWithOpts(m proto.Message, opts ProtoReflectOptions) AIPFilterOptions { return protoReflect(m, opts) } diff --git a/pika_aip_filter_proto_test.go b/pika_aip_filter_proto_test.go index 39d39ad..2be6cac 100644 --- a/pika_aip_filter_proto_test.go +++ b/pika_aip_filter_proto_test.go @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: Copyright (c) 2023-2024, Ctrl IQ, Inc. All rights reserved +// SPDX-FileCopyrightText: Copyright (c) 2023-2025, CTRL IQ, Inc. All rights reserved // SPDX-License-Identifier: Apache-2.0 package pika @@ -153,13 +153,13 @@ func TestComplete3AIP160(t *testing.T) { qs = Q[protoModel4](psql) _, err = qs.AIP160(filter, opts) require.NotNil(t, err) - require.Equal(t, "identifier non_existent is not allowed", err.Error()) + require.Equal(t, "identifier is not allowed: non_existent", err.Error()) filter = `bool = null` qs = Q[protoModel4](psql) _, err = qs.AIP160(filter, opts) require.NotNil(t, err) - require.Equal(t, "type NULL is not accepted for identifier bool", err.Error()) + require.Equal(t, "type is not accepted for identifier: NULL for identifier bool", err.Error()) filter = `status = 0` qs = Q[protoModel4](psql) @@ -192,13 +192,13 @@ func TestComplete3AIP160(t *testing.T) { qs = Q[protoModel4](psql) _, err = qs.AIP160(filter, opts) require.NotNil(t, err) - require.Equal(t, "type STRING is not accepted for identifier status", err.Error()) + require.Equal(t, "type is not accepted for identifier: STRING for identifier status", err.Error()) filter = `status = 99` qs = Q[protoModel4](psql) _, err = qs.AIP160(filter, opts) require.NotNil(t, err) - require.Equal(t, "value 99 is not accepted for identifier status", err.Error()) + require.Equal(t, "value is not accepted for identifier: 99 for identifier status", err.Error()) } func TestComplete3Exclude(t *testing.T) { @@ -224,7 +224,7 @@ func TestComplete3Exclude(t *testing.T) { qs = Q[protoModel4](psql) _, err = qs.AIP160(filter, opts) require.NotNil(t, err) - require.Equal(t, "identifier nullable_int is not allowed", err.Error()) + require.Equal(t, "identifier is not allowed: nullable_int", err.Error()) } func TestMultipleSameFieldOr(t *testing.T) { diff --git a/pika_aip_filter_psql_test.go b/pika_aip_filter_psql_test.go index 5d4e595..573a646 100644 --- a/pika_aip_filter_psql_test.go +++ b/pika_aip_filter_psql_test.go @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: Copyright (c) 2023-2024, Ctrl IQ, Inc. All rights reserved +// SPDX-FileCopyrightText: Copyright (c) 2023-2025, CTRL IQ, Inc. All rights reserved // SPDX-License-Identifier: Apache-2.0 package pika @@ -76,7 +76,7 @@ func TestAIP160SimpleEqualsAcceptableIdentifier(t *testing.T) { AcceptableIdentifiers: []string{"non_nullable"}, }) require.NotNil(t, err) - require.EqualError(t, err, "identifier invalid is not allowed") + require.EqualError(t, err, "identifier is not allowed: invalid") } func TestAIP160EqualsAndNull(t *testing.T) { diff --git a/pika_page_token.go b/pika_page_token.go index 03b39d2..3b9ccc9 100644 --- a/pika_page_token.go +++ b/pika_page_token.go @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: Copyright (c) 2023-2024, Ctrl IQ, Inc. All rights reserved +// SPDX-FileCopyrightText: Copyright (c) 2023-2025, CTRL IQ, Inc. All rights reserved // SPDX-License-Identifier: Apache-2.0 package pika @@ -7,8 +7,27 @@ import ( "encoding/base64" "encoding/json" "fmt" + "math" + + "github.com/pkg/errors" +) + +const ( + // MaxPageSize represents the maximum allowed page size for pagination + MaxPageSize = 100 +) + +// Static errors for err113 compliance +var ( + ErrPageTokenDecode = errors.New("failed to decode page token, make sure it is from a previous request") + ErrPageSizeTooSmall = errors.New("page size cannot be less than 1") + ErrPageSizeTooLarge = errors.New("page size cannot be greater than 100") + ErrOffsetTooLarge = errors.New("offset value too large") + ErrPageSizeValueTooLarge = errors.New("page size value too large") ) +// PageToken represents a pagination token that encodes the current state of pagination +// including offset, filter, ordering, and page size information. type PageToken[T any] struct { QuerySet[T] `json:"-"` @@ -18,6 +37,8 @@ type PageToken[T any] struct { PageSize uint `json:"page_size"` } +// Paginatable is an interface that defines the methods required for pagination support. +// Types implementing this interface can be used with pagination functionality. type Paginatable interface { GetFilter() string GetOrderBy() string @@ -25,6 +46,7 @@ type Paginatable interface { GetPageToken() string } +// PageRequest represents a request for paginated data with filtering and ordering options. type PageRequest struct { // Filter is a filter expression that restricts the results to return. Filter string @@ -39,6 +61,7 @@ type PageRequest struct { PageToken string } +// NewPageToken creates a new PageToken instance for the given type T. func NewPageToken[T any]() *PageToken[T] { return &PageToken[T]{} } @@ -61,7 +84,7 @@ func (p *PageToken[T]) Encode() (string, error) { func (p *PageToken[T]) Decode(s string) error { data, err := base64.URLEncoding.DecodeString(s) if err != nil { - return fmt.Errorf("failed to decode page token, make sure it is from a previous request") + return ErrPageTokenDecode } return json.Unmarshal(data, p) } @@ -82,12 +105,22 @@ func (p *PageToken[T]) pageToken(b QuerySet[T], options AIPFilterOptions) (Query // Return error if page size is less than 1 if p.PageSize < 1 { - return nil, fmt.Errorf("page size cannot be less than 1") + return nil, ErrPageSizeTooSmall } // Return error if page size is greater than 100 - if p.PageSize > 100 { - return nil, fmt.Errorf("page size cannot be greater than 100") + if p.PageSize > MaxPageSize { + return nil, ErrPageSizeTooLarge + } + + // Check for potential integer overflow when converting uint to int + if p.Offset > math.MaxInt { + return nil, ErrOffsetTooLarge + } + + // Check for potential integer overflow when converting PageSize uint to int + if p.PageSize > math.MaxInt { + return nil, ErrPageSizeValueTooLarge } b.Offset(int(p.Offset)) @@ -103,18 +136,22 @@ func (p *PageToken[T]) pageToken(b QuerySet[T], options AIPFilterOptions) (Query return b, nil } +// GetFilter returns the filter expression for the page request. func (p *PageRequest) GetFilter() string { return p.Filter } +// GetOrderBy returns the order by expression for the page request. func (p *PageRequest) GetOrderBy() string { return p.OrderBy } +// GetPageSize returns the maximum number of results to return for the page request. func (p *PageRequest) GetPageSize() int32 { return p.PageSize } +// GetPageToken returns the page token for the next request. func (p *PageRequest) GetPageToken() string { return p.PageToken } diff --git a/pika_psql.go b/pika_psql.go index eef03f7..59f2847 100644 --- a/pika_psql.go +++ b/pika_psql.go @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: Copyright (c) 2023-2024, Ctrl IQ, Inc. All rights reserved +// SPDX-FileCopyrightText: Copyright (c) 2023-2025, CTRL IQ, Inc. All rights reserved // SPDX-License-Identifier: Apache-2.0 package pika @@ -20,6 +20,24 @@ import ( _ "github.com/lib/pq" ) +const ( + // expectedFieldParts represents the expected number of parts when splitting a field name by dot + expectedFieldParts = 2 + // maxHintParts represents the maximum number of hint parts in a filter key + maxHintParts = 2 +) + +// Static errors for err113 compliance +var ( + ErrTooManyArguments = errors.New("too many arguments (count should be one pointer or none)") + ErrPageSizeNegative = errors.New("page size cannot be negative") + ErrCountNegative = errors.New("count cannot be negative") + ErrMissingArgument = errors.New("missing argument") + ErrInvalidOperator = errors.New("invalid operator") + ErrInvalidKey = errors.New("invalid key") + ErrBothModelsNil = errors.New("modelFirst and modelSecond are all nil, this is not allowed") +) + // Queryable includes all methods shared by sqlx.DB and sqlx.Tx, allowing // either type to be used interchangeably. // @@ -31,13 +49,13 @@ type Queryable interface { sqlx.QueryerContext sqlx.Preparer - GetContext(context.Context, interface{}, string, ...interface{}) error - MustExecContext(context.Context, string, ...interface{}) sql.Result - NamedExecContext(context.Context, string, interface{}) (sql.Result, error) - PrepareNamedContext(context.Context, string) (*sqlx.NamedStmt, error) - PreparexContext(context.Context, string) (*sqlx.Stmt, error) - QueryRowContext(context.Context, string, ...interface{}) *sql.Row - SelectContext(context.Context, interface{}, string, ...interface{}) error + GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error + MustExecContext(ctx context.Context, query string, args ...interface{}) sql.Result + NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) + PrepareNamedContext(ctx context.Context, query string) (*sqlx.NamedStmt, error) + PreparexContext(ctx context.Context, query string) (*sqlx.Stmt, error) + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row + SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error } var ( @@ -45,8 +63,11 @@ var ( _ Queryable = (*sqlx.Tx)(nil) ) +// PostgreSQL represents a PostgreSQL database connection with transaction support. +// It wraps sqlx.DB for regular database operations and sqlx.Tx for transactional operations. type PostgreSQL struct { *connBase + db *sqlx.DB tx *sqlx.Tx } @@ -55,12 +76,14 @@ type basePsql[T any] struct { *AIPFilter[T] *PageToken[T] *base - //nolint:structcheck // false positive + psql *PostgreSQL } +// CreateOption represents options for database insert operations. type CreateOption byte +// InsertOnConflictionDoNothing specifies that conflicts should be ignored during inserts. const InsertOnConflictionDoNothing CreateOption = 1 << iota // NewPostgreSQL returns a new PostgreSQL instance. @@ -77,6 +100,7 @@ func NewPostgreSQL(connectionString string) (*PostgreSQL, error) { }, nil } +// NewPostgreSQLFromDB creates a new PostgreSQL instance from an existing sqlx.DB connection. func NewPostgreSQLFromDB(db *sqlx.DB) *PostgreSQL { return &PostgreSQL{ connBase: &connBase{}, @@ -126,6 +150,7 @@ func (p *PostgreSQL) Rollback() error { return nil } +// Queryable returns the current queryable interface (either DB or transaction). func (p *PostgreSQL) Queryable() Queryable { if p.tx != nil { return p.tx @@ -134,14 +159,19 @@ func (p *PostgreSQL) Queryable() Queryable { return p.db } +// DB returns the underlying sqlx.DB instance. func (p *PostgreSQL) DB() *sqlx.DB { return p.db } +// Close closes the database connection. func (p *PostgreSQL) Close() error { return p.db.Close() } +// PSQLQuery creates a new QuerySet for PostgreSQL database operations. +// It initializes the query builder with metadata for the given type T and sets up +// table name resolution based on model metadata or pluralized model names. func PSQLQuery[T any](p *PostgreSQL) QuerySet[T] { b := &basePsql[T]{ AIPFilter: NewAIPFilter[T](), @@ -402,7 +432,7 @@ func (b *basePsql[T]) Count(ctx context.Context) (int, error) { filterStatement, args := b.queryWithFilters() preSelect := b.psqlSelectList(b.excludeColumns, b.includeColumns, false) // Strip preSelect from filterStatement - filterStatement = strings.Replace(filterStatement, preSelect, "", -1) + filterStatement = strings.ReplaceAll(filterStatement, preSelect, "") b.ignoreLimit = origIgnoreLimit b.ignoreOffset = origIgnoreOffset b.ignoreOrderBy = origIgnoreOrderBy @@ -489,7 +519,7 @@ func (b *basePsql[T]) DeleteQuery() (string, []interface{}) { modelName := b.metadata[pikaMetadataModelName] filterStatement, args := b.filterStatement() - filterStatement = strings.Replace(filterStatement, fmt.Sprintf("\"%s\".", modelName), "", -1) + filterStatement = strings.ReplaceAll(filterStatement, fmt.Sprintf("\"%s\".", modelName), "") q := fmt.Sprintf("DELETE FROM \"%s\"", b.metadata[PikaMetadataTableName]) q += filterStatement @@ -536,7 +566,7 @@ func (b *basePsql[T]) AIP160(filter string, options AIPFilterOptions) (QuerySet[ // Page tokens for gRPC func (b *basePsql[T]) GetPage(ctx context.Context, paginatable Paginatable, options AIPFilterOptions, countPointer ...*int) ([]*T, string, error) { if len(countPointer) > 1 { - return nil, "", fmt.Errorf("too many arguments (count should be one pointer or none)") + return nil, "", ErrTooManyArguments } if b.err != nil { @@ -545,7 +575,7 @@ func (b *basePsql[T]) GetPage(ctx context.Context, paginatable Paginatable, opti // Only decode if token is not empty if paginatable.GetPageToken() != "" { - err := b.PageToken.Decode(paginatable.GetPageToken()) + err := b.Decode(paginatable.GetPageToken()) if err != nil { return nil, "", err } @@ -554,7 +584,14 @@ func (b *basePsql[T]) GetPage(ctx context.Context, paginatable Paginatable, opti b.PageToken.Offset = 0 b.PageToken.Filter = paginatable.GetFilter() b.PageToken.OrderBy = paginatable.GetOrderBy() - b.PageToken.PageSize = uint(paginatable.GetPageSize()) + + // Validate page size to prevent integer overflow + pageSize := paginatable.GetPageSize() + if pageSize < 0 { + return nil, "", ErrPageSizeNegative + } + // Use proper bounds checking for int32 to uint conversion + b.PageSize = uint(pageSize) } qs, err := b.pageToken(b, options) @@ -579,12 +616,18 @@ func (b *basePsql[T]) GetPage(ctx context.Context, paginatable Paginatable, opti *countPointer[0] = count } + // Validate count to prevent integer overflow + if count < 0 { + return nil, "", ErrCountNegative + } + // If no more results after this page, return empty page token + // Safe conversion since we've validated count >= 0 if b.PageToken.Offset >= uint(count) { return result, "", nil } - tk, err := b.PageToken.Encode() + tk, err := b.Encode() if err != nil { return nil, "", err } @@ -641,15 +684,13 @@ func (b *basePsql[T]) Include(includes ...string) QuerySet[T] { // Return args, used for reflection func (b *basePsql[T]) GetArgs() *orderedmap.OrderedMap[string, interface{}] { return b.args -} - -// Return current table and module name, used for reflection +} // Return current table and module name, used for reflection func (b *basePsql[T]) GetModel() (string, string) { var x T modelName := reflect.TypeOf(x).Name() tableName := modelName ref := reflect.ValueOf(x) - for i := 0; i < ref.NumField(); i++ { + for i := range ref.NumField() { field := ref.Type().Field(i) if strings.Compare(field.Name, PikaMetadataTableName) == 0 { tableName = field.Tag.Get("pika") @@ -791,7 +832,7 @@ func (b *basePsql[T]) filterStatement() (string, []any) { // If mapping found, replace with numbered parameter v = fmt.Sprintf("$%d", mapping[noWildcard[1:]]) } else { - b.err = fmt.Errorf("missing argument: %s", noWildcard) + b.err = fmt.Errorf("%w: %s", ErrMissingArgument, noWildcard) return "", nil } } @@ -805,7 +846,7 @@ func (b *basePsql[T]) filterStatement() (string, []any) { if strings.Contains(k, "__") { parts := strings.Split(k, "__") k = parts[0] - op := fmt.Sprintf("__%s", parts[1]) + op := "__" + parts[1] // IN requires the value wrapped in ANY // as go-pika sends the value as a slice @@ -856,12 +897,12 @@ func (b *basePsql[T]) filterStatement() (string, []any) { if op == HintLike || op == HintNotLike || op == HintILike || op == HintNotILike { // If a start wildcard was found, then add a prefix if startWildcard { - v = fmt.Sprintf("'%%' || %s", v) + v = "'%' || " + v } // If an end wildcard was found, then add a suffix if endWildcard { - v = fmt.Sprintf("%s || '%%'", v) + v = v + " || '%'" } } @@ -872,8 +913,8 @@ func (b *basePsql[T]) filterStatement() (string, []any) { } extraHintOp := op - if len(parts) > 2 { - extraHintOp = fmt.Sprintf("__%s", parts[2]) + if len(parts) > maxHintParts { + extraHintOp = "__" + parts[2] } // If AND then set andOr to AND regardless of filter.innerOr @@ -905,7 +946,7 @@ func (b *basePsql[T]) filterStatement() (string, []any) { var ok bool operator, ok = operators[op] if !ok { - b.err = fmt.Errorf("invalid operator: %s", operator) + b.err = fmt.Errorf("%w: %s", ErrInvalidOperator, operator) return "", nil } } @@ -918,8 +959,8 @@ func (b *basePsql[T]) filterStatement() (string, []any) { if strings.Contains(clean, ".") { // Split by dot, then join with quotes parts := strings.Split(clean, ".") - if len(parts) != 2 { - b.err = fmt.Errorf("invalid key: %s", k) + if len(parts) != expectedFieldParts { + b.err = fmt.Errorf("%w: %s", ErrInvalidKey, k) return "", nil } finalK = fmt.Sprintf("\"%s\".\"%s\"", parts[0], parts[1]) @@ -1048,7 +1089,7 @@ func (b *basePsql[T]) psqlSelectList(excludeColumns []string, includeColumns []s columns := make([]column, 0, ref.NumField()) // Iterate through fields to get tags - for i := 0; i < ref.NumField(); i++ { + for i := range ref.NumField() { field := ref.Type().Field(i) tag := field.Tag.Get("db") // By default, pikaTag = tag @@ -1091,8 +1132,8 @@ func (b *basePsql[T]) psqlSelectList(excludeColumns []string, includeColumns []s var selectColumns []string for _, column := range columns { if column.pika != "" { - values := strings.SplitN(column.pika, ".", 2) - if len(values) == 2 { + values := strings.SplitN(column.pika, ".", expectedFieldParts) + if len(values) == expectedFieldParts { if val, ok := b.replaceFields[values[0]]; ok { // Need to replace fields from other tables with associated model prefixs // These fields are defined in the current model, but their values are from other tables @@ -1113,7 +1154,7 @@ func (b *basePsql[T]) psqlSelectList(excludeColumns []string, includeColumns []s return strings.Join(selectColumns, ", ") } - selectStr := fmt.Sprintf("SELECT %s", strings.Join(selectColumns, ", ")) + selectStr := "SELECT " + strings.Join(selectColumns, ", ") q := fmt.Sprintf("%s %s", selectStr, strings.Join(fromStrs, ",")) return q @@ -1148,7 +1189,7 @@ func (b *basePsql[T]) psqlCreateQuery(value *T, options ...CreateOption) (string // Iterate through fields to get tags xi := 0 - for i := 0; i < ref.Elem().NumField(); i++ { + for i := range ref.Elem().NumField() { field := ref.Elem().Type().Field(i) tag := field.Tag.Get("db") // Ignore "-" tags (or empty tags) @@ -1186,7 +1227,7 @@ func (b *basePsql[T]) psqlCreateQuery(value *T, options ...CreateOption) (string selectList := b.psqlSelectList(b.excludeColumns, b.includeColumns, true) // Remove the model name prefix from the select list // since we are inserting into the table - selectList = strings.Replace(selectList, fmt.Sprintf("\"%s\".", modelName), "", -1) + selectList = strings.ReplaceAll(selectList, fmt.Sprintf("\"%s\".", modelName), "") conflict := "" if InsertOnConflictionDoNothing&getOption(options...) != 0 { @@ -1196,7 +1237,7 @@ func (b *basePsql[T]) psqlCreateQuery(value *T, options ...CreateOption) (string // Convert value to arguments args := make([]interface{}, 0, ref.Elem().NumField()) - for i := 0; i < ref.Elem().NumField(); i++ { + for i := range ref.Elem().NumField() { field := ref.Elem().Type().Field(i) tag := fmt.Sprintf("\"%s\"", field.Tag.Get("db")) @@ -1224,7 +1265,7 @@ func (b *basePsql[T]) psqlUpdateQuery(value *T) (string, []any) { // Iterate through fields to get tags xi := 0 - for i := 0; i < ref.Elem().NumField(); i++ { + for i := range ref.Elem().NumField() { field := ref.Elem().Type().Field(i) tag := field.Tag.Get("db") // Ignore "-" tags (or empty tags) @@ -1270,7 +1311,7 @@ func (b *basePsql[T]) psqlUpdateQuery(value *T) (string, []any) { selectList := b.psqlSelectList(b.excludeColumns, b.includeColumns, true) // Remove the model name prefix from the select list // since we are inserting into the table - selectList = strings.Replace(selectList, fmt.Sprintf("\"%s\".", modelName), "", -1) + selectList = strings.ReplaceAll(selectList, fmt.Sprintf("\"%s\".", modelName), "") q := fmt.Sprintf("UPDATE \"%s\" SET ", tableName) // Add columns to update @@ -1282,11 +1323,11 @@ func (b *basePsql[T]) psqlUpdateQuery(value *T) (string, []any) { } // Add where clause - filterStatement = strings.Replace(filterStatement, fmt.Sprintf("\"%s\".", modelName), "", -1) + filterStatement = strings.ReplaceAll(filterStatement, fmt.Sprintf("\"%s\".", modelName), "") q += fmt.Sprintf("%s RETURNING %s", filterStatement, selectList) // Convert value to arguments - for i := 0; i < ref.Elem().NumField(); i++ { + for i := range ref.Elem().NumField() { field := ref.Elem().Type().Field(i) tag := fmt.Sprintf("\"%s\"", field.Tag.Get("db")) @@ -1318,7 +1359,7 @@ func (b *basePsql[T]) commonJoin(joinType string, modelFirst, modelSecond interf } if modelFirst == nil && modelSecond == nil { - b.err = fmt.Errorf("modelFirst and modelSecond are all nil, this is not allowed") + b.err = ErrBothModelsNil return b } @@ -1390,7 +1431,7 @@ func getQuerySetInfo(val interface{}) (string, string) { mname := reflect.TypeOf(val).Name() tname := mname ref := reflect.ValueOf(val) - for i := 0; i < ref.NumField(); i++ { + for i := range ref.NumField() { field := ref.Type().Field(i) if strings.Compare(field.Name, PikaMetadataTableName) == 0 { tname = field.Tag.Get("pika") @@ -1413,15 +1454,13 @@ func isTarget(val interface{}) bool { } // Replace the old place holder with new ones -// -//nolint:predeclared -func replacePlaceHolder(query string, old, new []int) string { - if len(old) != len(new) { +func replacePlaceHolder(query string, old, newIdx []int) string { + if len(old) != len(newIdx) { return query } for idx := range old { - query = strings.Replace(query, fmt.Sprintf("$%d", old[idx]), fmt.Sprintf("$%d", new[idx]), 1) + query = strings.Replace(query, fmt.Sprintf("$%d", old[idx]), fmt.Sprintf("$%d", newIdx[idx]), 1) } return query @@ -1434,7 +1473,7 @@ func generateRangeSlice(start, length int) *[]int { } ret := make([]int, length) - for idx := 0; idx < length; idx++ { + for idx := range length { ret[idx] = start + idx } return &ret diff --git a/pika_psql_experimental.go b/pika_psql_experimental.go index b5396b0..b6a9b1a 100644 --- a/pika_psql_experimental.go +++ b/pika_psql_experimental.go @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: Copyright (c) 2023-2024, Ctrl IQ, Inc. All rights reserved +// SPDX-FileCopyrightText: Copyright (c) 2023-2025, CTRL IQ, Inc. All rights reserved // SPDX-License-Identifier: Apache-2.0 package pika @@ -7,6 +7,13 @@ import ( "context" "fmt" "reflect" + + "github.com/pkg/errors" +) + +// Static errors for err113 compliance +var ( + ErrIDNotFound = errors.New("id not found") ) func (b *basePsql[T]) findID(x *T) any { @@ -45,7 +52,7 @@ func (b *basePsql[T]) F(keyval ...any) QuerySet[T] { func (b *basePsql[T]) D(ctx context.Context, x *T) error { id := b.findID(x) if id == nil { - return fmt.Errorf("id not found") + return ErrIDNotFound } qs := b.F("id", id) @@ -65,7 +72,7 @@ func (b *basePsql[T]) Transaction(ctx context.Context) (QuerySet[T], error) { func (b *basePsql[T]) U(ctx context.Context, x *T) error { id := b.findID(x) if id == nil { - return fmt.Errorf("id not found") + return ErrIDNotFound } qs := b.F("id", id) diff --git a/pika_psql_experimental_test.go b/pika_psql_experimental_test.go index 84950d7..7f8605c 100644 --- a/pika_psql_experimental_test.go +++ b/pika_psql_experimental_test.go @@ -1,3 +1,5 @@ +// SPDX-FileCopyrightText: Copyright (c) 2023-2024, Ctrl IQ, Inc. All rights reserved +// SPDX-License-Identifier: Apache-2.0 package pika import ( diff --git a/pika_psql_test.go b/pika_psql_test.go index a1b5ff7..9ad4c7a 100644 --- a/pika_psql_test.go +++ b/pika_psql_test.go @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: Copyright (c) 2023-2024, Ctrl IQ, Inc. All rights reserved +// SPDX-FileCopyrightText: Copyright (c) 2023-2025, CTRL IQ, Inc. All rights reserved // SPDX-License-Identifier: Apache-2.0 package pika diff --git a/testproto/gen.go b/testproto/gen.go index adc0981..3321f47 100644 --- a/testproto/gen.go +++ b/testproto/gen.go @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: Copyright (c) 2023-2024, Ctrl IQ, Inc. All rights reserved +// SPDX-FileCopyrightText: Copyright (c) 2023-2025, CTRL IQ, Inc. All rights reserved // SPDX-License-Identifier: Apache-2.0 //go:generate protoc --go_opt=paths=source_relative --go_out=. test.proto diff --git a/utils.go b/utils.go index c5fea56..cc17940 100644 --- a/utils.go +++ b/utils.go @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: Copyright (c) 2023-2024, Ctrl IQ, Inc. All rights reserved +// SPDX-FileCopyrightText: Copyright (c) 2023-2025, CTRL IQ, Inc. All rights reserved // SPDX-License-Identifier: Apache-2.0 package pika @@ -9,11 +9,24 @@ import ( "reflect" "strings" + "github.com/pkg/errors" "github.com/sirupsen/logrus" orderedmap "github.com/wk8/go-ordered-map/v2" "go.ciq.dev/pika/parser" ) +const ( + // expectedQueryParts represents the expected number of parts when splitting a query by '=' + expectedQueryParts = 2 +) + +// Static errors for err113 compliance +var ( + ErrInvalidFilter = errors.New("invalid filter") + ErrFilterKeyContainsExclamation = errors.New("filter key contains exclamation mark") +) + +//nolint:revive var ( logger *logrus.Logger pikaMetadataModelName = "PikaMetadataModelName" @@ -171,7 +184,6 @@ type pikaFiltering struct { innerOr bool } -//nolint:structcheck type base struct { filters []pikaFiltering args *orderedmap.OrderedMap[string, interface{}] @@ -233,12 +245,12 @@ func (b *base) filter(innerOr bool, or bool, queries ...string) { newFilters := orderedmap.New[string, string]() for _, query := range queries { split := strings.Split(query, "=") - if len(split) != 2 { - b.err = fmt.Errorf("invalid filter: %s", query) + if len(split) != expectedQueryParts { + b.err = fmt.Errorf("%w: %s", ErrInvalidFilter, query) break } if strings.Contains(split[0], "!") { - b.err = fmt.Errorf("filter key contains exclamation mark: %s", query) + b.err = fmt.Errorf("%w: %s", ErrFilterKeyContainsExclamation, query) break } newFilters.Set(findEmptyForKey(split[0], newFilters), split[1]) @@ -283,7 +295,7 @@ func (c *connBase) TableAlias(src string, dst string) { c.tableAlias = make(map[string]string) } if _, ok := c.tableAlias[src]; ok { - panic(fmt.Sprintf("duplicate table alias: %s", src)) + panic("duplicate table alias: " + src) } c.tableAlias[src] = dst } @@ -300,7 +312,7 @@ func getPikaMetadata[T any]() map[string]string { metadata[pikaMetadataModelName] = modelName // Iterate through fields to get tags - for i := 0; i < ref.NumField(); i++ { + for i := range ref.NumField() { field := ref.Type().Field(i) // We only care about fields starting with "Pika" @@ -309,7 +321,7 @@ func getPikaMetadata[T any]() map[string]string { tag := field.Tag.Get("pika") if _, ok := metadata[field.Name]; ok { - panic(fmt.Sprintf("duplicate Pika metadata field: %s", field.Name)) + panic("duplicate Pika metadata field: " + field.Name) } metadata[field.Name] = tag } @@ -323,7 +335,7 @@ func getPikaMetadata[T any]() map[string]string { } if _, ok := metadata[field.Name]; ok { - panic(fmt.Sprintf("duplicate Pika database field: %s", field.Name)) + panic("duplicate Pika database field: " + field.Name) } metadata[tag] = field.Type.String()