From 326d4d1dd6d5c02b7a9c9befba0856e8d270ded2 Mon Sep 17 00:00:00 2001 From: Tom Fleet Date: Sun, 9 Nov 2025 09:22:02 +0000 Subject: [PATCH 1/4] Add an internal parse package for type safe parsing --- internal/arg/arg.go | 100 +++++++--------------- internal/arg/arg_test.go | 110 +++++------------------- internal/parse/parse.go | 156 +++++++++++++++++++++++++++++++++++ internal/parse/parse_test.go | 141 +++++++++++++++++++++++++++++++ 4 files changed, 347 insertions(+), 160 deletions(-) create mode 100644 internal/parse/parse.go create mode 100644 internal/parse/parse_test.go diff --git a/internal/arg/arg.go b/internal/arg/arg.go index f37db93..237fe2d 100644 --- a/internal/arg/arg.go +++ b/internal/arg/arg.go @@ -15,8 +15,8 @@ import ( "unsafe" "go.followtheprocess.codes/cli/arg" - "go.followtheprocess.codes/cli/flag" "go.followtheprocess.codes/cli/internal/constraints" + "go.followtheprocess.codes/cli/internal/parse" ) // TODO(@FollowTheProcess): LOTS of duplicated stuff with internal/flag. @@ -261,99 +261,99 @@ func (a Arg[T]) Set(str string) error { switch typ := any(*a.value).(type) { case int: - val, err := parseInt[int](0)(str) + val, err := parse.Int(str) if err != nil { - return errParse(a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, typ, err) } *a.value = *cast[T](&val) return nil case int8: - val, err := parseInt[int8](bits8)(str) + val, err := parse.Int8(str) if err != nil { - return errParse(a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, typ, err) } *a.value = *cast[T](&val) return nil case int16: - val, err := parseInt[int16](bits16)(str) + val, err := parse.Int16(str) if err != nil { - return errParse(a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, typ, err) } *a.value = *cast[T](&val) return nil case int32: - val, err := parseInt[int32](bits32)(str) + val, err := parse.Int32(str) if err != nil { - return errParse(a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, typ, err) } *a.value = *cast[T](&val) return nil case int64: - val, err := parseInt[int64](bits64)(str) + val, err := parse.Int64(str) if err != nil { - return errParse(a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, typ, err) } *a.value = *cast[T](&val) return nil case uint: - val, err := parseUint[uint](0)(str) + val, err := parse.Uint(str) if err != nil { - return errParse(a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, typ, err) } *a.value = *cast[T](&val) return nil case uint8: - val, err := parseUint[uint8](bits8)(str) + val, err := parse.Uint8(str) if err != nil { - return errParse(a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, typ, err) } *a.value = *cast[T](&val) return nil case uint16: - val, err := parseUint[uint16](bits16)(str) + val, err := parse.Uint16(str) if err != nil { - return errParse(a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, typ, err) } *a.value = *cast[T](&val) return nil case uint32: - val, err := parseUint[uint32](bits32)(str) + val, err := parse.Uint32(str) if err != nil { - return errParse(a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, typ, err) } *a.value = *cast[T](&val) return nil case uint64: - val, err := parseUint[uint64](bits64)(str) + val, err := parse.Uint64(str) if err != nil { - return errParse(a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, typ, err) } *a.value = *cast[T](&val) return nil case uintptr: - val, err := parseUint[uint64](bits64)(str) + val, err := parse.Uint64(str) if err != nil { - return errParse(a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, typ, err) } *a.value = *cast[T](&val) @@ -362,7 +362,7 @@ func (a Arg[T]) Set(str string) error { case float32: val, err := parseFloat[float32](bits32)(str) if err != nil { - return errParse(a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, typ, err) } *a.value = *cast[T](&val) @@ -371,7 +371,7 @@ func (a Arg[T]) Set(str string) error { case float64: val, err := parseFloat[float64](bits64)(str) if err != nil { - return errParse(a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, typ, err) } *a.value = *cast[T](&val) @@ -385,7 +385,7 @@ func (a Arg[T]) Set(str string) error { case bool: val, err := strconv.ParseBool(str) if err != nil { - return errParse(a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, typ, err) } *a.value = *cast[T](&val) @@ -394,7 +394,7 @@ func (a Arg[T]) Set(str string) error { case []byte: val, err := hex.DecodeString(strings.TrimSpace(str)) if err != nil { - return errParse(a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, typ, err) } *a.value = *cast[T](&val) @@ -403,7 +403,7 @@ func (a Arg[T]) Set(str string) error { case time.Time: val, err := time.Parse(time.RFC3339, str) if err != nil { - return errParse(a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, typ, err) } *a.value = *cast[T](&val) @@ -412,7 +412,7 @@ func (a Arg[T]) Set(str string) error { case time.Duration: val, err := time.ParseDuration(str) if err != nil { - return errParse(a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, typ, err) } *a.value = *cast[T](&val) @@ -421,7 +421,7 @@ func (a Arg[T]) Set(str string) error { case net.IP: val := net.ParseIP(str) if val == nil { - return errParse(a.name, str, typ, errors.New("invalid IP address")) + return parse.Error(parse.KindArgument, a.name, str, typ, errors.New("invalid IP address")) } *a.value = *cast[T](&val) @@ -516,46 +516,6 @@ func cast[T2, T1 any](v *T1) *T2 { return (*T2)(unsafe.Pointer(v)) } -// errParse is a helper to quickly return a consistent error in the face of flag -// value parsing errors. -func errParse[T flag.Flaggable](name, str string, typ T, err error) error { - return fmt.Errorf( - "arg %q received invalid value %q (expected %T), detail: %w", - name, - str, - typ, - err, - ) -} - -// parseInt is a generic helper to parse all signed integers, given a bit size. -// -// It returns the parsed value or an error. -func parseInt[T constraints.Signed](bits int) func(str string) (T, error) { - return func(str string) (T, error) { - val, err := strconv.ParseInt(str, 0, bits) - if err != nil { - return 0, err - } - - return T(val), nil - } -} - -// parseUint is a generic helper to parse all signed integers, given a bit size. -// -// It returns the parsed value or an error. -func parseUint[T constraints.Unsigned](bits int) func(str string) (T, error) { - return func(str string) (T, error) { - val, err := strconv.ParseUint(str, 0, bits) - if err != nil { - return 0, err - } - - return T(val), nil - } -} - // parseFloat is a generic helper to parse floating point numbers, given a bit size. // // It returns the parsed value or an error. diff --git a/internal/arg/arg_test.go b/internal/arg/arg_test.go index 63b9907..55e508c 100644 --- a/internal/arg/arg_test.go +++ b/internal/arg/arg_test.go @@ -2,11 +2,13 @@ package arg_test import ( "bytes" + "errors" "net" "testing" "time" "go.followtheprocess.codes/cli/internal/arg" + "go.followtheprocess.codes/cli/internal/parse" "go.followtheprocess.codes/test" ) @@ -42,11 +44,7 @@ func TestArgableTypes(t *testing.T) { err = intArg.Set("word") test.Err(t, err) - test.Equal( - t, - err.Error(), - `arg "int" received invalid value "word" (expected int), detail: strconv.ParseInt: parsing "word": invalid syntax`, - ) + test.True(t, errors.Is(err, parse.Err)) }) t.Run("int8 valid", func(t *testing.T) { @@ -70,11 +68,7 @@ func TestArgableTypes(t *testing.T) { err = intArg.Set("word") test.Err(t, err) - test.Equal( - t, - err.Error(), - `arg "int" received invalid value "word" (expected int8), detail: strconv.ParseInt: parsing "word": invalid syntax`, - ) + test.True(t, errors.Is(err, parse.Err)) }) t.Run("int16 valid", func(t *testing.T) { @@ -98,11 +92,7 @@ func TestArgableTypes(t *testing.T) { err = intArg.Set("word") test.Err(t, err) - test.Equal( - t, - err.Error(), - `arg "int" received invalid value "word" (expected int16), detail: strconv.ParseInt: parsing "word": invalid syntax`, - ) + test.True(t, errors.Is(err, parse.Err)) }) t.Run("int32 valid", func(t *testing.T) { @@ -126,11 +116,7 @@ func TestArgableTypes(t *testing.T) { err = intArg.Set("word") test.Err(t, err) - test.Equal( - t, - err.Error(), - `arg "int" received invalid value "word" (expected int32), detail: strconv.ParseInt: parsing "word": invalid syntax`, - ) + test.True(t, errors.Is(err, parse.Err)) }) t.Run("int64 valid", func(t *testing.T) { @@ -154,11 +140,7 @@ func TestArgableTypes(t *testing.T) { err = intArg.Set("word") test.Err(t, err) - test.Equal( - t, - err.Error(), - `arg "int" received invalid value "word" (expected int64), detail: strconv.ParseInt: parsing "word": invalid syntax`, - ) + test.True(t, errors.Is(err, parse.Err)) }) t.Run("uint valid", func(t *testing.T) { @@ -182,11 +164,7 @@ func TestArgableTypes(t *testing.T) { err = intArg.Set("word") test.Err(t, err) - test.Equal( - t, - err.Error(), - `arg "uint" received invalid value "word" (expected uint), detail: strconv.ParseUint: parsing "word": invalid syntax`, - ) + test.True(t, errors.Is(err, parse.Err)) }) t.Run("uint8 valid", func(t *testing.T) { @@ -210,11 +188,7 @@ func TestArgableTypes(t *testing.T) { err = intArg.Set("word") test.Err(t, err) - test.Equal( - t, - err.Error(), - `arg "uint" received invalid value "word" (expected uint8), detail: strconv.ParseUint: parsing "word": invalid syntax`, - ) + test.True(t, errors.Is(err, parse.Err)) }) t.Run("uint16 valid", func(t *testing.T) { @@ -238,11 +212,7 @@ func TestArgableTypes(t *testing.T) { err = intArg.Set("word") test.Err(t, err) - test.Equal( - t, - err.Error(), - `arg "uint" received invalid value "word" (expected uint16), detail: strconv.ParseUint: parsing "word": invalid syntax`, - ) + test.True(t, errors.Is(err, parse.Err)) }) t.Run("uint32 valid", func(t *testing.T) { @@ -266,11 +236,7 @@ func TestArgableTypes(t *testing.T) { err = intArg.Set("word") test.Err(t, err) - test.Equal( - t, - err.Error(), - `arg "uint" received invalid value "word" (expected uint32), detail: strconv.ParseUint: parsing "word": invalid syntax`, - ) + test.True(t, errors.Is(err, parse.Err)) }) t.Run("uint64 valid", func(t *testing.T) { @@ -294,11 +260,7 @@ func TestArgableTypes(t *testing.T) { err = intArg.Set("word") test.Err(t, err) - test.Equal( - t, - err.Error(), - `arg "uint" received invalid value "word" (expected uint64), detail: strconv.ParseUint: parsing "word": invalid syntax`, - ) + test.True(t, errors.Is(err, parse.Err)) }) t.Run("uintptr valid", func(t *testing.T) { @@ -322,11 +284,7 @@ func TestArgableTypes(t *testing.T) { err = intArg.Set("word") test.Err(t, err) - test.Equal( - t, - err.Error(), - `arg "uintptr" received invalid value "word" (expected uintptr), detail: strconv.ParseUint: parsing "word": invalid syntax`, - ) + test.True(t, errors.Is(err, parse.Err)) }) t.Run("float32 valid", func(t *testing.T) { @@ -350,11 +308,7 @@ func TestArgableTypes(t *testing.T) { err = floatArg.Set("word") test.Err(t, err) - test.Equal( - t, - err.Error(), - `arg "float" received invalid value "word" (expected float32), detail: strconv.ParseFloat: parsing "word": invalid syntax`, - ) + test.True(t, errors.Is(err, parse.Err)) }) t.Run("float64 valid", func(t *testing.T) { @@ -378,11 +332,7 @@ func TestArgableTypes(t *testing.T) { err = floatArg.Set("word") test.Err(t, err) - test.Equal( - t, - err.Error(), - `arg "float" received invalid value "word" (expected float64), detail: strconv.ParseFloat: parsing "word": invalid syntax`, - ) + test.True(t, errors.Is(err, parse.Err)) }) t.Run("bool valid", func(t *testing.T) { @@ -406,11 +356,7 @@ func TestArgableTypes(t *testing.T) { err = boolArg.Set("word") test.Err(t, err) - test.Equal( - t, - err.Error(), - `arg "bool" received invalid value "word" (expected bool), detail: strconv.ParseBool: parsing "word": invalid syntax`, - ) + test.True(t, errors.Is(err, parse.Err)) }) // No invalid case as all command line args are strings anyway so no real way of @@ -449,11 +395,7 @@ func TestArgableTypes(t *testing.T) { err = byteArg.Set("0xF") test.Err(t, err) - test.Equal( - t, - err.Error(), - `arg "byte" received invalid value "0xF" (expected []uint8), detail: encoding/hex: invalid byte: U+0078 'x'`, - ) + test.True(t, errors.Is(err, parse.Err)) }) t.Run("time.Time valid", func(t *testing.T) { @@ -480,11 +422,7 @@ func TestArgableTypes(t *testing.T) { err = timeArg.Set("not a time") test.Err(t, err) - test.Equal( - t, - err.Error(), - `arg "time" received invalid value "not a time" (expected time.Time), detail: parsing time "not a time" as "2006-01-02T15:04:05Z07:00": cannot parse "not a time" as "2006"`, - ) + test.True(t, errors.Is(err, parse.Err)) }) t.Run("time.Duration valid", func(t *testing.T) { @@ -511,11 +449,7 @@ func TestArgableTypes(t *testing.T) { err = durationArg.Set("not a duration") test.Err(t, err) - test.Equal( - t, - err.Error(), - `arg "duration" received invalid value "not a duration" (expected time.Duration), detail: time: invalid duration "not a duration"`, - ) + test.True(t, errors.Is(err, parse.Err)) }) t.Run("ip valid", func(t *testing.T) { @@ -539,10 +473,6 @@ func TestArgableTypes(t *testing.T) { err = ipArg.Set("not an ip") test.Err(t, err) - test.Equal( - t, - err.Error(), - `arg "ip" received invalid value "not an ip" (expected net.IP), detail: invalid IP address`, - ) + test.True(t, errors.Is(err, parse.Err)) }) } diff --git a/internal/parse/parse.go b/internal/parse/parse.go new file mode 100644 index 0000000..cb6c9ab --- /dev/null +++ b/internal/parse/parse.go @@ -0,0 +1,156 @@ +// Package parse provides functions to parse strings into Go types and produce +// detailed, consistent errors. +// +// It is used across both internal/flag and internal/arg to provide consistency. +package parse + +import ( + "errors" + "fmt" + "strconv" +) + +// Kind is the kind of parsing being done, either argument or flag. +type Kind string + +// Err is a generic parse error. +// +// Errors returned from the [Error] function will match this in a call +// to [errors.Is]. +var Err = errors.New("parse error") + +const ( + // KindArgument is the [Kind] used for argument parsing. + KindArgument Kind = "argument" + + // KindFlag is the [Kind] used for flag parsing. + KindFlag Kind = "flag" +) + +const ( + bits8 = 8 << iota // 8 bit integer + bits16 // 16 bit integer + bits32 // 32 bit integer + bits64 // 64 bit integer +) + +const base10 = 10 + +// Error produces a formatted parse error. +// +// The kind should must be [KindArgument] or [KindFlag], with name and str being the +// name of the arg/flag and the invalid text that triggered the error. +// +// The type T is the type we were parsing str into and err is any underlying +// error e.g. from strconv. +// +// // Make a flag parse error +// var force bool +// return parse.Error(parse.KindArgument, "force", "faklse", force, strconv.ErrSyntax) +func Error[T any](kind Kind, name, str string, typ T, err error) error { + // Ordinarily I wouldn't have a package like this concern itself with + // details of other packages (like flag/arg) but given this package exists to produce consistent + // behaviour and clear error messages in the narrow context of this cli framework, then + // it makes sense the error is defined here too. + return fmt.Errorf("%w: %s %q received invalid value %q (expected %T): %w", Err, kind, name, str, typ, err) +} + +// Int parses an int from a string. +func Int(str string) (int, error) { + val, err := strconv.ParseInt(str, base10, 0) + if err != nil { + return 0, err + } + + return int(val), nil +} + +// Int8 parses an int8 from a string. +func Int8(str string) (int8, error) { + val, err := strconv.ParseInt(str, base10, bits8) + if err != nil { + return 0, err + } + + return int8(val), nil +} + +// Int16 parses an int16 from a string. +func Int16(str string) (int16, error) { + val, err := strconv.ParseInt(str, base10, bits16) + if err != nil { + return 0, err + } + + return int16(val), nil +} + +// Int32 parses an int32 from a string. +func Int32(str string) (int32, error) { + val, err := strconv.ParseInt(str, base10, bits32) + if err != nil { + return 0, err + } + + return int32(val), nil +} + +// Int64 parses an int64 from a string. +func Int64(str string) (int64, error) { + val, err := strconv.ParseInt(str, base10, bits64) + if err != nil { + return 0, err + } + + return val, nil +} + +// Uint parses a uint from a string. +func Uint(str string) (uint, error) { + val, err := strconv.ParseUint(str, base10, 0) + if err != nil { + return 0, err + } + + return uint(val), nil +} + +// Uint8 parses an uint8 from a string. +func Uint8(str string) (uint8, error) { + val, err := strconv.ParseUint(str, base10, bits8) + if err != nil { + return 0, err + } + + return uint8(val), nil +} + +// Uint16 parses an uint16 from a string. +func Uint16(str string) (uint16, error) { + val, err := strconv.ParseUint(str, base10, bits16) + if err != nil { + return 0, err + } + + return uint16(val), nil +} + +// Uint32 parses an uint32 from a string. +func Uint32(str string) (uint32, error) { + val, err := strconv.ParseUint(str, base10, bits32) + if err != nil { + return 0, err + } + + return uint32(val), nil +} + +// Uint64 parses an uint64 from a string. +func Uint64(str string) (uint64, error) { + val, err := strconv.ParseUint(str, base10, bits64) + if err != nil { + return 0, err + } + + return val, nil +} diff --git a/internal/parse/parse_test.go b/internal/parse/parse_test.go new file mode 100644 index 0000000..a70de61 --- /dev/null +++ b/internal/parse/parse_test.go @@ -0,0 +1,141 @@ +package parse //nolint:testpackage // I need the base and bits values and don't want to export them. + +import ( + "strconv" + "testing" + "testing/quick" +) + +// These are basically all just testing that I haven't broken anything +// by wrapping strconv and saves me having to write lots of test cases +// by hand. + +func TestInt(t *testing.T) { + test := Int + + reference := func(str string) (int, error) { + val, err := strconv.ParseInt(str, base10, 0) + return int(val), err + } + + if err := quick.CheckEqual(test, reference, nil); err != nil { + t.Error(err) + } +} + +func TestInt8(t *testing.T) { + test := Int8 + + reference := func(str string) (int8, error) { + val, err := strconv.ParseInt(str, base10, bits8) + return int8(val), err + } + + if err := quick.CheckEqual(test, reference, nil); err != nil { + t.Error(err) + } +} + +func TestInt16(t *testing.T) { + test := Int16 + + reference := func(str string) (int16, error) { + val, err := strconv.ParseInt(str, base10, bits16) + return int16(val), err + } + + if err := quick.CheckEqual(test, reference, nil); err != nil { + t.Error(err) + } +} + +func TestInt32(t *testing.T) { + test := Int32 + + reference := func(str string) (int32, error) { + val, err := strconv.ParseInt(str, base10, bits32) + return int32(val), err + } + + if err := quick.CheckEqual(test, reference, nil); err != nil { + t.Error(err) + } +} + +func TestInt64(t *testing.T) { + test := Int64 + + reference := func(str string) (int64, error) { + val, err := strconv.ParseInt(str, base10, bits64) + return val, err + } + + if err := quick.CheckEqual(test, reference, nil); err != nil { + t.Error(err) + } +} + +func TestUint(t *testing.T) { + test := Uint + + reference := func(str string) (uint, error) { + val, err := strconv.ParseUint(str, base10, 0) + return uint(val), err + } + + if err := quick.CheckEqual(test, reference, nil); err != nil { + t.Error(err) + } +} + +func TestUint8(t *testing.T) { + test := Uint8 + + reference := func(str string) (uint8, error) { + val, err := strconv.ParseUint(str, base10, bits8) + return uint8(val), err + } + + if err := quick.CheckEqual(test, reference, nil); err != nil { + t.Error(err) + } +} + +func TestUint16(t *testing.T) { + test := Uint16 + + reference := func(str string) (uint16, error) { + val, err := strconv.ParseUint(str, base10, bits16) + return uint16(val), err + } + + if err := quick.CheckEqual(test, reference, nil); err != nil { + t.Error(err) + } +} + +func TestUint32(t *testing.T) { + test := Uint32 + + reference := func(str string) (uint32, error) { + val, err := strconv.ParseUint(str, base10, bits32) + return uint32(val), err + } + + if err := quick.CheckEqual(test, reference, nil); err != nil { + t.Error(err) + } +} + +func TestUint64(t *testing.T) { + test := Uint64 + + reference := func(str string) (uint64, error) { + val, err := strconv.ParseUint(str, base10, bits64) + return val, err + } + + if err := quick.CheckEqual(test, reference, nil); err != nil { + t.Error(err) + } +} From fa5643073ab0a9e6d9d573e90adef22487de2caf Mon Sep 17 00:00:00 2001 From: Tom Fleet Date: Sun, 9 Nov 2025 09:33:51 +0000 Subject: [PATCH 2/4] Add cast to the parse package --- internal/arg/arg.go | 81 ++++++++++-------------------------- internal/parse/parse.go | 45 ++++++++++++++++++++ internal/parse/parse_test.go | 26 ++++++++++++ 3 files changed, 92 insertions(+), 60 deletions(-) diff --git a/internal/arg/arg.go b/internal/arg/arg.go index 237fe2d..9a31d89 100644 --- a/internal/arg/arg.go +++ b/internal/arg/arg.go @@ -12,7 +12,6 @@ import ( "strings" "time" "unicode" - "unsafe" "go.followtheprocess.codes/cli/arg" "go.followtheprocess.codes/cli/internal/constraints" @@ -266,7 +265,7 @@ func (a Arg[T]) Set(str string) error { return parse.Error(parse.KindArgument, a.name, str, typ, err) } - *a.value = *cast[T](&val) + *a.value = *parse.Cast[T](&val) return nil case int8: @@ -275,7 +274,7 @@ func (a Arg[T]) Set(str string) error { return parse.Error(parse.KindArgument, a.name, str, typ, err) } - *a.value = *cast[T](&val) + *a.value = *parse.Cast[T](&val) return nil case int16: @@ -284,7 +283,7 @@ func (a Arg[T]) Set(str string) error { return parse.Error(parse.KindArgument, a.name, str, typ, err) } - *a.value = *cast[T](&val) + *a.value = *parse.Cast[T](&val) return nil case int32: @@ -293,7 +292,7 @@ func (a Arg[T]) Set(str string) error { return parse.Error(parse.KindArgument, a.name, str, typ, err) } - *a.value = *cast[T](&val) + *a.value = *parse.Cast[T](&val) return nil case int64: @@ -302,7 +301,7 @@ func (a Arg[T]) Set(str string) error { return parse.Error(parse.KindArgument, a.name, str, typ, err) } - *a.value = *cast[T](&val) + *a.value = *parse.Cast[T](&val) return nil case uint: @@ -311,7 +310,7 @@ func (a Arg[T]) Set(str string) error { return parse.Error(parse.KindArgument, a.name, str, typ, err) } - *a.value = *cast[T](&val) + *a.value = *parse.Cast[T](&val) return nil case uint8: @@ -320,7 +319,7 @@ func (a Arg[T]) Set(str string) error { return parse.Error(parse.KindArgument, a.name, str, typ, err) } - *a.value = *cast[T](&val) + *a.value = *parse.Cast[T](&val) return nil case uint16: @@ -329,7 +328,7 @@ func (a Arg[T]) Set(str string) error { return parse.Error(parse.KindArgument, a.name, str, typ, err) } - *a.value = *cast[T](&val) + *a.value = *parse.Cast[T](&val) return nil case uint32: @@ -338,7 +337,7 @@ func (a Arg[T]) Set(str string) error { return parse.Error(parse.KindArgument, a.name, str, typ, err) } - *a.value = *cast[T](&val) + *a.value = *parse.Cast[T](&val) return nil case uint64: @@ -347,7 +346,7 @@ func (a Arg[T]) Set(str string) error { return parse.Error(parse.KindArgument, a.name, str, typ, err) } - *a.value = *cast[T](&val) + *a.value = *parse.Cast[T](&val) return nil case uintptr: @@ -356,30 +355,30 @@ func (a Arg[T]) Set(str string) error { return parse.Error(parse.KindArgument, a.name, str, typ, err) } - *a.value = *cast[T](&val) + *a.value = *parse.Cast[T](&val) return nil case float32: - val, err := parseFloat[float32](bits32)(str) + val, err := parse.Float32(str) if err != nil { return parse.Error(parse.KindArgument, a.name, str, typ, err) } - *a.value = *cast[T](&val) + *a.value = *parse.Cast[T](&val) return nil case float64: - val, err := parseFloat[float64](bits64)(str) + val, err := parse.Float64(str) if err != nil { return parse.Error(parse.KindArgument, a.name, str, typ, err) } - *a.value = *cast[T](&val) + *a.value = *parse.Cast[T](&val) return nil case string: val := str - *a.value = *cast[T](&val) + *a.value = *parse.Cast[T](&val) return nil case bool: @@ -388,7 +387,7 @@ func (a Arg[T]) Set(str string) error { return parse.Error(parse.KindArgument, a.name, str, typ, err) } - *a.value = *cast[T](&val) + *a.value = *parse.Cast[T](&val) return nil case []byte: @@ -397,7 +396,7 @@ func (a Arg[T]) Set(str string) error { return parse.Error(parse.KindArgument, a.name, str, typ, err) } - *a.value = *cast[T](&val) + *a.value = *parse.Cast[T](&val) return nil case time.Time: @@ -406,7 +405,7 @@ func (a Arg[T]) Set(str string) error { return parse.Error(parse.KindArgument, a.name, str, typ, err) } - *a.value = *cast[T](&val) + *a.value = *parse.Cast[T](&val) return nil case time.Duration: @@ -415,7 +414,7 @@ func (a Arg[T]) Set(str string) error { return parse.Error(parse.KindArgument, a.name, str, typ, err) } - *a.value = *cast[T](&val) + *a.value = *parse.Cast[T](&val) return nil case net.IP: @@ -424,7 +423,7 @@ func (a Arg[T]) Set(str string) error { return parse.Error(parse.KindArgument, a.name, str, typ, errors.New("invalid IP address")) } - *a.value = *cast[T](&val) + *a.value = *parse.Cast[T](&val) return nil default: @@ -491,41 +490,3 @@ func formatFloat[T ~float32 | ~float64](bits int) func(T) string { return strconv.FormatFloat(float64(in), 'g', -1, bits) } } - -// cast converts a *T1 to a *T2, we use it here when we know (via generics and compile time checks) -// that e.g. the Flag.value is a string, but we can't directly do Flag.value = "value" because -// we can't assign a string to a generic 'T', but we *know* that the value *is* a string because when -// instantiating a Flag[T], you have to provide (or compiler has to infer) Flag[string]. -// -// # Safety -// -// This function uses [unsafe.Pointer] underneath to reassign the types but we know this is safe to do -// based on the compile time checks provided by generics. Further, it fits the following valid pattern -// specified in the docs for [unsafe.Pointer]. -// -// Conversion of a *T1 to Pointer to *T2 -// -// Provided that T2 is no larger than T1 and that the two share an equivalent -// memory layout, this conversion allows reinterpreting data of one type as -// data of another type. -// -// This describes our use case as we're converting a *T to e.g a *string but *only* when we know -// that a Flag[T] is actually Flag[string], so the memory layout and size is guaranteed by the -// compiler to be equivalent. -func cast[T2, T1 any](v *T1) *T2 { - return (*T2)(unsafe.Pointer(v)) -} - -// parseFloat is a generic helper to parse floating point numbers, given a bit size. -// -// It returns the parsed value or an error. -func parseFloat[T ~float32 | ~float64](bits int) func(str string) (T, error) { - return func(str string) (T, error) { - val, err := strconv.ParseFloat(str, bits) - if err != nil { - return 0, err - } - - return T(val), nil - } -} diff --git a/internal/parse/parse.go b/internal/parse/parse.go index cb6c9ab..8549c57 100644 --- a/internal/parse/parse.go +++ b/internal/parse/parse.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "strconv" + "unsafe" ) // Kind is the kind of parsing being done, either argument or flag. @@ -154,3 +155,47 @@ func Uint64(str string) (uint64, error) { return val, nil } + +// Float32 parses a float32 from a string. +func Float32(str string) (float32, error) { + val, err := strconv.ParseFloat(str, bits32) + if err != nil { + return 0, err + } + + return float32(val), nil +} + +// Float64 parses a float64 from a string. +func Float64(str string) (float64, error) { + val, err := strconv.ParseFloat(str, bits64) + if err != nil { + return 0, err + } + + return float64(val), nil +} + +// Cast converts a *T1 to a *T2, we use it here when we know (via generics and compile time checks) +// that e.g. the Flag.value is a string, but we can't directly do Flag.value = "value" because +// we can't assign a string to a generic 'T', but we *know* that the value *is* a string because when +// instantiating a Flag[T], you have to provide (or compiler has to infer) Flag[string]. +// +// # Safety +// +// This function uses [unsafe.Pointer] underneath to reassign the types but we know this is safe to do +// based on the compile time checks provided by generics. Further, it fits the following valid pattern +// specified in the docs for [unsafe.Pointer]. +// +// Conversion of a *T1 to Pointer to *T2 +// +// Provided that T2 is no larger than T1 and that the two share an equivalent +// memory layout, this conversion allows reinterpreting data of one type as +// data of another type. +// +// This describes our use case as we're converting a *T to e.g a *string but *only* when we know +// that a Flag[T] is actually Flag[string], so the memory layout and size is guaranteed by the +// compiler to be equivalent. +func Cast[T2, T1 any](v *T1) *T2 { + return (*T2)(unsafe.Pointer(v)) +} diff --git a/internal/parse/parse_test.go b/internal/parse/parse_test.go index a70de61..2a46a5c 100644 --- a/internal/parse/parse_test.go +++ b/internal/parse/parse_test.go @@ -139,3 +139,29 @@ func TestUint64(t *testing.T) { t.Error(err) } } + +func TestFloat32(t *testing.T) { + test := Float32 + + reference := func(str string) (float32, error) { + val, err := strconv.ParseFloat(str, bits32) + return float32(val), err + } + + if err := quick.CheckEqual(test, reference, nil); err != nil { + t.Error(err) + } +} + +func TestFloat64(t *testing.T) { + test := Float64 + + reference := func(str string) (float64, error) { + val, err := strconv.ParseFloat(str, bits64) + return float64(val), err + } + + if err := quick.CheckEqual(test, reference, nil); err != nil { + t.Error(err) + } +} From 482a30555888cad9e9bf5f9d3d206c0279e14796 Mon Sep 17 00:00:00 2001 From: Tom Fleet Date: Sun, 9 Nov 2025 09:49:24 +0000 Subject: [PATCH 3/4] Add a format package to do the inverse --- internal/arg/arg.go | 79 ++++++++++++---------------------- internal/format/format.go | 41 ++++++++++++++++++ internal/format/format_test.go | 61 ++++++++++++++++++++++++++ 3 files changed, 129 insertions(+), 52 deletions(-) create mode 100644 internal/format/format.go create mode 100644 internal/format/format_test.go diff --git a/internal/arg/arg.go b/internal/arg/arg.go index 9a31d89..a5578d5 100644 --- a/internal/arg/arg.go +++ b/internal/arg/arg.go @@ -14,7 +14,7 @@ import ( "unicode" "go.followtheprocess.codes/cli/arg" - "go.followtheprocess.codes/cli/internal/constraints" + "go.followtheprocess.codes/cli/internal/format" "go.followtheprocess.codes/cli/internal/parse" ) @@ -22,14 +22,6 @@ import ( // Once we know this is the direction to go down, then we should combine all the shared // stuff and use it from each package -const ( - _ = 4 << iota // Unused - bits8 // 8 bit integer - bits16 // 16 bit integer - bits32 // 32 bit integer - bits64 // 64 bit integer -) - const ( typeInt = "int" typeInt8 = "int8" @@ -104,31 +96,31 @@ func (a Arg[T]) Default() string { switch typ := any(*a.config.DefaultValue).(type) { case int: - return formatInt(typ) + return format.Int(typ) case int8: - return formatInt(typ) + return format.Int(typ) case int16: - return formatInt(typ) + return format.Int(typ) case int32: - return formatInt(typ) + return format.Int(typ) case int64: - return formatInt(typ) + return format.Int(typ) case uint: - return formatUint(typ) + return format.Uint(typ) case uint8: - return formatUint(typ) + return format.Uint(typ) case uint16: - return formatUint(typ) + return format.Uint(typ) case uint32: - return formatUint(typ) + return format.Uint(typ) case uint64: - return formatUint(typ) + return format.Uint(typ) case uintptr: - return formatUint(typ) + return format.Uint(typ) case float32: - return formatFloat[float32](bits32)(typ) + return format.Float32(typ) case float64: - return formatFloat[float64](bits64)(typ) + return format.Float64(typ) case string: return typ case bool: @@ -156,31 +148,31 @@ func (a Arg[T]) String() string { switch typ := any(*a.value).(type) { case int: - return formatInt(typ) + return format.Int(typ) case int8: - return formatInt(typ) + return format.Int(typ) case int16: - return formatInt(typ) + return format.Int(typ) case int32: - return formatInt(typ) + return format.Int(typ) case int64: - return formatInt(typ) + return format.Int(typ) case uint: - return formatUint(typ) + return format.Uint(typ) case uint8: - return formatUint(typ) + return format.Uint(typ) case uint16: - return formatUint(typ) + return format.Uint(typ) case uint32: - return formatUint(typ) + return format.Uint(typ) case uint64: - return formatUint(typ) + return format.Uint(typ) case uintptr: - return formatUint(typ) + return format.Uint(typ) case float32: - return formatFloat[float32](bits32)(typ) + return format.Float32(typ) case float64: - return formatFloat[float64](bits64)(typ) + return format.Float64(typ) case string: return typ case bool: @@ -473,20 +465,3 @@ func validateArgName(name string) error { return nil } - -// formatInt is a generic helper to return a string representation of any signed integer. -func formatInt[T constraints.Signed](in T) string { - return strconv.FormatInt(int64(in), 10) -} - -// formatUint is a generic helper to return a string representation of any unsigned integer. -func formatUint[T constraints.Unsigned](in T) string { - return strconv.FormatUint(uint64(in), 10) -} - -// formatFloat is a generic helper to return a string representation of any floating point digit. -func formatFloat[T ~float32 | ~float64](bits int) func(T) string { - return func(in T) string { - return strconv.FormatFloat(float64(in), 'g', -1, bits) - } -} diff --git a/internal/format/format.go b/internal/format/format.go new file mode 100644 index 0000000..0f4d540 --- /dev/null +++ b/internal/format/format.go @@ -0,0 +1,41 @@ +// Package format is the inverse of parse. +// +// It formats arg/flag values as string representations. +package format + +import ( + "strconv" + + "go.followtheprocess.codes/cli/internal/constraints" +) + +const ( + base10 = 10 + floatFmt = 'g' + floatPrecision = -1 +) + +const ( + bits32 = 32 << iota + bits64 +) + +// Int returns a string representation of an integer. +func Int[T constraints.Signed](n T) string { + return strconv.FormatInt(int64(n), base10) +} + +// Uint returns a string representation of an unsigned integer. +func Uint[T constraints.Unsigned](n T) string { + return strconv.FormatUint(uint64(n), base10) +} + +// Float32 returns a string representation of a float32. +func Float32(f float32) string { + return strconv.FormatFloat(float64(f), floatFmt, floatPrecision, bits32) +} + +// Float64 returns a string representation of a float64. +func Float64(f float64) string { + return strconv.FormatFloat(float64(f), floatFmt, floatPrecision, bits64) +} diff --git a/internal/format/format_test.go b/internal/format/format_test.go new file mode 100644 index 0000000..8fa7eb5 --- /dev/null +++ b/internal/format/format_test.go @@ -0,0 +1,61 @@ +package format //nolint:testpackage // I need the base and bits values and don't want to export them. + +import ( + "strconv" + "testing" + "testing/quick" +) + +func TestInt(t *testing.T) { + //nolint:gocritic // It wants me to "unlambda" this but it's generic so I can't + test := func(n int) string { + return Int(n) + } + + reference := func(n int) string { + return strconv.FormatInt(int64(n), base10) + } + + if err := quick.CheckEqual(test, reference, nil); err != nil { + t.Error(err) + } +} + +func TestUint(t *testing.T) { + //nolint:gocritic // It wants me to "unlambda" this but it's generic so I can't + test := func(n uint) string { + return Uint(n) + } + + reference := func(n uint) string { + return strconv.FormatUint(uint64(n), base10) + } + + if err := quick.CheckEqual(test, reference, nil); err != nil { + t.Error(err) + } +} + +func TestFloat32(t *testing.T) { + test := Float32 + + reference := func(f float32) string { + return strconv.FormatFloat(float64(f), floatFmt, floatPrecision, bits32) + } + + if err := quick.CheckEqual(test, reference, nil); err != nil { + t.Error(err) + } +} + +func TestFloat64(t *testing.T) { + test := Float64 + + reference := func(f float64) string { + return strconv.FormatFloat(float64(f), floatFmt, floatPrecision, bits64) + } + + if err := quick.CheckEqual(test, reference, nil); err != nil { + t.Error(err) + } +} From bed0cf935273d3998540a45db9bc96cdc3b257d1 Mon Sep 17 00:00:00 2001 From: Tom Fleet Date: Sun, 9 Nov 2025 09:58:21 +0000 Subject: [PATCH 4/4] Add type names to the format package --- internal/arg/arg.go | 64 ++++++++++++--------------------------- internal/format/format.go | 23 ++++++++++++++ 2 files changed, 42 insertions(+), 45 deletions(-) diff --git a/internal/arg/arg.go b/internal/arg/arg.go index a5578d5..393bbb8 100644 --- a/internal/arg/arg.go +++ b/internal/arg/arg.go @@ -18,32 +18,6 @@ import ( "go.followtheprocess.codes/cli/internal/parse" ) -// TODO(@FollowTheProcess): LOTS of duplicated stuff with internal/flag. -// Once we know this is the direction to go down, then we should combine all the shared -// stuff and use it from each package - -const ( - typeInt = "int" - typeInt8 = "int8" - typeInt16 = "int16" - typeInt32 = "int32" - typeInt64 = "int64" - typeUint = "uint" - typeUint8 = "uint8" - typeUint16 = "uint16" - typeUint32 = "uint32" - typeUint64 = "uint64" - typeUintptr = "uintptr" - typeFloat32 = "float32" - typeFloat64 = "float64" - typeString = "string" - typeBool = "bool" - typeBytesHex = "bytesHex" - typeTime = "time" - typeDuration = "duration" - typeIP = "ip" -) - var _ Value = Arg[string]{} // This will fail if we violate our Value interface // Arg represents a single command line argument. @@ -200,43 +174,43 @@ func (a Arg[T]) Type() string { switch typ := any(*a.value).(type) { case int: - return typeInt + return format.TypeInt case int8: - return typeInt8 + return format.TypeInt8 case int16: - return typeInt16 + return format.TypeInt16 case int32: - return typeInt32 + return format.TypeInt32 case int64: - return typeInt64 + return format.TypeInt64 case uint: - return typeUint + return format.TypeUint case uint8: - return typeUint8 + return format.TypeUint8 case uint16: - return typeUint16 + return format.TypeUint16 case uint32: - return typeUint32 + return format.TypeUint32 case uint64: - return typeUint64 + return format.TypeUint64 case uintptr: - return typeUintptr + return format.TypeUintptr case float32: - return typeFloat32 + return format.TypeFloat32 case float64: - return typeFloat64 + return format.TypeFloat64 case string: - return typeString + return format.TypeString case bool: - return typeBool + return format.TypeBool case []byte: - return typeBytesHex + return format.TypeBytesHex case time.Time: - return typeTime + return format.TypeTime case time.Duration: - return typeDuration + return format.TypeDuration case net.IP: - return typeIP + return format.TypeIP default: return fmt.Sprintf("%T", typ) } diff --git a/internal/format/format.go b/internal/format/format.go index 0f4d540..6be9a87 100644 --- a/internal/format/format.go +++ b/internal/format/format.go @@ -20,6 +20,29 @@ const ( bits64 ) +// Type names. +const ( + TypeInt = "int" + TypeInt8 = "int8" + TypeInt16 = "int16" + TypeInt32 = "int32" + TypeInt64 = "int64" + TypeUint = "uint" + TypeUint8 = "uint8" + TypeUint16 = "uint16" + TypeUint32 = "uint32" + TypeUint64 = "uint64" + TypeUintptr = "uintptr" + TypeFloat32 = "float32" + TypeFloat64 = "float64" + TypeString = "string" + TypeBool = "bool" + TypeBytesHex = "bytesHex" + TypeTime = "time" + TypeDuration = "duration" + TypeIP = "ip" +) + // Int returns a string representation of an integer. func Int[T constraints.Signed](n T) string { return strconv.FormatInt(int64(n), base10)