diff --git a/CHANGELOG.md b/CHANGELOG.md index cf18092d9..044da4f46 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/) and this p ### Changed +- `pkg/luhn`: Replaced duplicate `computeLuhnCheckDigit` implementations in the `fr` and `it` regimes with a shared `CheckDigit` function. +- `org`, `bill`, `regimes`, `addons/it/sdi`: Replaced duplicate `hasTaxIDCode` helpers with a shared `Party.HasTaxIDCode()` method. - `bill`: `Invoice.Invert()` returns an error if the invoice has the `bypass` tag. - `num`: `AmountFromString` now limits precision to 18 significant digits. - `tax`: Added `$defs` and `$refs` to the `tax.RegimeCode` JSON schema diff --git a/addons/it/sdi/bill.go b/addons/it/sdi/bill.go index 87f8670f9..0fddb5963 100644 --- a/addons/it/sdi/bill.go +++ b/addons/it/sdi/bill.go @@ -160,7 +160,7 @@ func validateCustomer(value interface{}) error { ), validation.Field(&customer.Identities, validation.When( - isItalianParty(customer) && !hasTaxIDCode(customer), + isItalianParty(customer) && !customer.HasTaxIDCode(), org.RequireIdentityKey(it.IdentityKeyFiscalCode), ), validation.Skip, @@ -287,10 +287,6 @@ func validateItalianTelephone(value any) error { ) } -func hasTaxIDCode(party *org.Party) bool { - return party != nil && party.TaxID != nil && party.TaxID.Code != "" -} - func hasFiscalCode(party *org.Party) bool { if party == nil { return false diff --git a/bill/invoice.go b/bill/invoice.go index 95134d4e4..de7298500 100644 --- a/bill/invoice.go +++ b/bill/invoice.go @@ -214,17 +214,13 @@ func validateInvoiceCustomer(value any) error { return validation.ValidateStruct(p, validation.Field(&p.Name, validation.When( - partyHasTaxIDCode(p), + p.HasTaxIDCode(), validation.Required, ), ), ) } -func partyHasTaxIDCode(party *org.Party) bool { - return party != nil && party.TaxID != nil && party.TaxID.Code != "" -} - // Invert effectively reverses the invoice by inverting the sign of all quantity // or amount values. Caution should be taken when using this method as // advances will also be inverted, while payment terms will remain the same, diff --git a/org/party.go b/org/party.go index 4c03a0f4a..df538a150 100644 --- a/org/party.go +++ b/org/party.go @@ -50,6 +50,11 @@ type Party struct { Meta cbc.Meta `json:"meta,omitempty" jsonschema:"title=Meta"` } +// HasTaxIDCode returns true if the party has a tax identity with a non-empty code. +func (p *Party) HasTaxIDCode() bool { + return p != nil && p.TaxID != nil && p.TaxID.Code != "" +} + // Calculate will perform basic normalization of the party's data without // using any tax regime or addon. func (p *Party) Calculate() error { diff --git a/org/party_test.go b/org/party_test.go index 19b0d735d..3a9ca14f2 100644 --- a/org/party_test.go +++ b/org/party_test.go @@ -137,6 +137,20 @@ func TestPartyValidation(t *testing.T) { assert.NoError(t, party.Validate()) assert.Equal(t, "DE", party.GetRegime().String()) }) + t.Run("has tax id code", func(t *testing.T) { + var nilParty *org.Party + assert.False(t, nilParty.HasTaxIDCode()) + + assert.False(t, (&org.Party{Name: "Test"}).HasTaxIDCode()) + + assert.False(t, (&org.Party{ + TaxID: &tax.Identity{Country: "ES"}, + }).HasTaxIDCode()) + + assert.True(t, (&org.Party{ + TaxID: &tax.Identity{Country: "ES", Code: "B85905495"}, + }).HasTaxIDCode()) + }) t.Run("with regime and bad code", func(t *testing.T) { party := org.Party{ Regime: tax.WithRegime("DE"), diff --git a/pkg/luhn/luhn.go b/pkg/luhn/luhn.go index 8e4528f10..4a5210a8f 100644 --- a/pkg/luhn/luhn.go +++ b/pkg/luhn/luhn.go @@ -4,6 +4,7 @@ package luhn import ( "regexp" + "strconv" "github.com/invopop/gobl/cbc" ) @@ -40,3 +41,23 @@ func Check(code cbc.Code) bool { return checksum%10 == 0 } + +// CheckDigit computes the Luhn check digit for the given numeric string. +// The caller is responsible for ensuring the input contains only ASCII digit +// characters; no validation is performed on the input. +func CheckDigit(number string) string { + sum := 0 + pos := 0 + for i := len(number) - 1; i >= 0; i-- { + digit := int(number[i] - '0') + if pos%2 == 0 { + digit *= 2 + if digit > 9 { + digit -= 9 + } + } + sum += digit + pos++ + } + return strconv.Itoa((10 - sum%10) % 10) +} diff --git a/pkg/luhn/luhn_test.go b/pkg/luhn/luhn_test.go index ef7034176..13d86c251 100644 --- a/pkg/luhn/luhn_test.go +++ b/pkg/luhn/luhn_test.go @@ -51,3 +51,29 @@ func TestCheck(t *testing.T) { }) } } + +func TestCheckDigit(t *testing.T) { + t.Parallel() + tests := []struct { + name string + number string + want string + }{ + {name: "single digit", number: "0", want: "0"}, + {name: "credit card base", number: "411111111111111", want: "1"}, + {name: "luhn example", number: "7992739871", want: "3"}, + {name: "italian VAT", number: "0271580010", want: "4"}, + {name: "french SIREN", number: "73282932", want: "0"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := luhn.CheckDigit(tt.number) + assert.Equal(t, tt.want, got) + // Verify consistency: number + check digit should pass Check. + full := cbc.Code(tt.number + got) + assert.True(t, luhn.Check(full), "number+check digit should pass Check: %s", full) + }) + } +} diff --git a/regimes/be/invoices.go b/regimes/be/invoices.go index 15b4582ea..cc31308be 100644 --- a/regimes/be/invoices.go +++ b/regimes/be/invoices.go @@ -32,7 +32,7 @@ func validateInvoiceSupplier(value any) error { ), validation.Field(&p.Identities, validation.When( - !hasTaxIDCode(p), + !p.HasTaxIDCode(), org.RequireIdentityType(IdentityTypeBCE), ), validation.Skip, @@ -40,10 +40,6 @@ func validateInvoiceSupplier(value any) error { ) } -func hasTaxIDCode(party *org.Party) bool { - return party != nil && party.TaxID != nil && party.TaxID.Code != "" -} - func hasIdentityBCE(party *org.Party) bool { if party == nil || len(party.Identities) == 0 { return false diff --git a/regimes/de/invoices.go b/regimes/de/invoices.go index d3dc0ce7d..42c7bea11 100644 --- a/regimes/de/invoices.go +++ b/regimes/de/invoices.go @@ -37,7 +37,7 @@ func validateInvoiceSupplier(value any) error { ), validation.Field(&p.Identities, validation.When( - !hasTaxIDCode(p), + !p.HasTaxIDCode(), org.RequireIdentityKey(IdentityKeyTaxNumber), ), validation.Skip, @@ -49,10 +49,6 @@ func isSimplified(inv *bill.Invoice) bool { return inv.HasTags(tax.TagSimplified) } -func hasTaxIDCode(party *org.Party) bool { - return party != nil && party.TaxID != nil && party.TaxID.Code != "" -} - func hasIdentityTaxNumber(party *org.Party) bool { if party == nil || len(party.Identities) == 0 { return false diff --git a/regimes/dk/invoices.go b/regimes/dk/invoices.go index f1e7148db..caf38cc2f 100644 --- a/regimes/dk/invoices.go +++ b/regimes/dk/invoices.go @@ -32,7 +32,7 @@ func validateInvoiceSupplier(value any) error { ), validation.Field(&p.Identities, validation.When( - !hasTaxIDCode(p), + !p.HasTaxIDCode(), org.RequireIdentityType(IdentityTypeCVR), ), validation.Skip, @@ -40,10 +40,6 @@ func validateInvoiceSupplier(value any) error { ) } -func hasTaxIDCode(party *org.Party) bool { - return party != nil && party.TaxID != nil && party.TaxID.Code != "" -} - func hasIdentityCVR(party *org.Party) bool { if party == nil || len(party.Identities) == 0 { return false diff --git a/regimes/fr/invoices.go b/regimes/fr/invoices.go index 3750ad21e..f37084742 100644 --- a/regimes/fr/invoices.go +++ b/regimes/fr/invoices.go @@ -32,7 +32,7 @@ func validateInvoiceSupplier(value any) error { ), validation.Field(&p.Identities, validation.When( - !hasTaxIDCode(p), + !p.HasTaxIDCode(), org.RequireIdentityType(IdentityTypeSIREN, IdentityTypeSIRET), ), validation.Skip, @@ -40,10 +40,6 @@ func validateInvoiceSupplier(value any) error { ) } -func hasTaxIDCode(party *org.Party) bool { - return party != nil && party.TaxID != nil && party.TaxID.Code != "" -} - func hasSupplierIdentity(party *org.Party) bool { if party == nil || len(party.Identities) == 0 { return false diff --git a/regimes/fr/tax_identity.go b/regimes/fr/tax_identity.go index 24f97332b..6a8f5e241 100644 --- a/regimes/fr/tax_identity.go +++ b/regimes/fr/tax_identity.go @@ -7,6 +7,7 @@ import ( "strconv" "github.com/invopop/gobl/cbc" + "github.com/invopop/gobl/pkg/luhn" "github.com/invopop/gobl/tax" "github.com/invopop/validation" ) @@ -98,32 +99,9 @@ func validateSIRENTaxCode(value any) error { base := str[:8] chk := str[8:] - v := computeLuhnCheckDigit(base) - if chk != v { + if luhn.CheckDigit(base) != chk { return errors.New("checksum mismatch") } return nil } - -// TODO: refactor this into a shareable method. -func computeLuhnCheckDigit(number string) string { - sum := 0 - pos := 0 - - for i := len(number) - 1; i >= 0; i-- { - digit := int(number[i] - '0') - - if pos%2 == 0 { - digit *= 2 - if digit > 9 { - digit -= 9 - } - } - - sum += digit - pos++ - } - - return strconv.FormatInt(int64((10-(sum%10))%10), 10) -} diff --git a/regimes/it/tax_identity.go b/regimes/it/tax_identity.go index df0c71602..c453137cf 100644 --- a/regimes/it/tax_identity.go +++ b/regimes/it/tax_identity.go @@ -7,9 +7,9 @@ package it import ( "errors" - "strconv" "github.com/invopop/gobl/cbc" + "github.com/invopop/gobl/pkg/luhn" "github.com/invopop/gobl/tax" "github.com/invopop/validation" ) @@ -49,32 +49,9 @@ func validateTaxCode(value interface{}) error { return errors.New("invalid length") } - chk := computeLuhnCheckDigit(str[:10]) - if chk != str[10:] { + if luhn.CheckDigit(str[:10]) != str[10:] { return errors.New("invalid check digit") } return nil } - -// TODO: refactor this into a shareable method. -func computeLuhnCheckDigit(number string) string { - sum := 0 - pos := 0 - - for i := len(number) - 1; i >= 0; i-- { - digit := int(number[i] - '0') - - if pos%2 == 0 { - digit *= 2 - if digit > 9 { - digit -= 9 - } - } - - sum += digit - pos++ - } - - return strconv.FormatInt(int64((10-(sum%10))%10), 10) -} diff --git a/regimes/nl/invoices.go b/regimes/nl/invoices.go index dce602be5..2f178c51f 100644 --- a/regimes/nl/invoices.go +++ b/regimes/nl/invoices.go @@ -32,7 +32,7 @@ func validateInvoiceSupplier(value interface{}) error { ), validation.Field(&p.Identities, validation.When( - !hasTaxIDCode(p), + !p.HasTaxIDCode(), org.RequireIdentityType(IdentityTypeKVK, IdentityTypeOIN), ), validation.Skip, @@ -40,10 +40,6 @@ func validateInvoiceSupplier(value interface{}) error { ) } -func hasTaxIDCode(party *org.Party) bool { - return party != nil && party.TaxID != nil && party.TaxID.Code != "" -} - func hasSupplierIdentity(party *org.Party) bool { if party == nil || len(party.Identities) == 0 { return false