diff --git a/duration.go b/duration.go index 832ce5d..b690b50 100644 --- a/duration.go +++ b/duration.go @@ -8,7 +8,6 @@ import ( "strconv" "strings" "time" - "unicode" ) // Duration holds all the smaller units that make up the duration @@ -51,33 +50,36 @@ var ( // Parse attempts to parse the given duration string into a *Duration, // if parsing fails an error is returned instead. func Parse(d string) (*Duration, error) { - state := parsingPeriod - duration := &Duration{} - num := "" - var err error + var ( + state = parsingPeriod + duration = &Duration{} + num string + err error + rank = 8 // designator order, strictly descending: Y=7 M=6 W=5 D=4 H=3 M=2 S=1 + ) switch { case strings.HasPrefix(d, "P"): // standard duration + d = d[1:] case strings.HasPrefix(d, "-P"): // negative duration duration.Negative = true - d = strings.TrimPrefix(d, "-") // remove the negative sign + d = d[2:] default: return nil, ErrUnexpectedInput } for _, char := range d { switch char { - case 'P': - if state != parsingPeriod { + case 'T': + if state == parsingTime || num != "" { return nil, ErrUnexpectedInput } - case 'T': state = parsingTime case 'Y': - if state != parsingPeriod { + if state != parsingPeriod || rank <= 7 { return nil, ErrUnexpectedInput } - + rank = 7 duration.Years, err = strconv.ParseFloat(num, 64) if err != nil { return nil, err @@ -85,12 +87,20 @@ func Parse(d string) (*Duration, error) { num = "" case 'M': if state == parsingPeriod { + if rank <= 6 { + return nil, ErrUnexpectedInput + } + rank = 6 duration.Months, err = strconv.ParseFloat(num, 64) if err != nil { return nil, err } num = "" - } else if state == parsingTime { + } else { + if rank <= 2 { + return nil, ErrUnexpectedInput + } + rank = 2 duration.Minutes, err = strconv.ParseFloat(num, 64) if err != nil { return nil, err @@ -98,57 +108,62 @@ func Parse(d string) (*Duration, error) { num = "" } case 'W': - if state != parsingPeriod { + if state != parsingPeriod || rank <= 5 { return nil, ErrUnexpectedInput } - + rank = 5 duration.Weeks, err = strconv.ParseFloat(num, 64) if err != nil { return nil, err } num = "" case 'D': - if state != parsingPeriod { + if state != parsingPeriod || rank <= 4 { return nil, ErrUnexpectedInput } - + rank = 4 duration.Days, err = strconv.ParseFloat(num, 64) if err != nil { return nil, err } num = "" case 'H': - if state != parsingTime { + if state != parsingTime || rank <= 3 { return nil, ErrUnexpectedInput } - + rank = 3 duration.Hours, err = strconv.ParseFloat(num, 64) if err != nil { return nil, err } num = "" case 'S': - if state != parsingTime { + if state != parsingTime || rank <= 1 { return nil, ErrUnexpectedInput } - + rank = 1 duration.Seconds, err = strconv.ParseFloat(num, 64) if err != nil { return nil, err } num = "" default: - if unicode.IsNumber(char) || char == '.' { + if (char >= '0' && char <= '9') || char == '.' { num += string(char) continue } - return nil, ErrUnexpectedInput } } if num != "" { return nil, ErrIncompleteExpr } + if state == parsingPeriod && rank == 8 { + return nil, ErrIncompleteExpr + } + if state == parsingTime && rank > 3 { + return nil, ErrIncompleteExpr + } return duration, nil } diff --git a/duration_test.go b/duration_test.go index 20e29e0..4cd642a 100644 --- a/duration_test.go +++ b/duration_test.go @@ -83,6 +83,12 @@ func TestParse(t *testing.T) { }, errorMatchFn: noError, }, + { + name: "bare P with no components", + args: args{d: "P"}, + want: nil, + errorMatchFn: newMatchFn(ErrIncompleteExpr), + }, { name: "no unit after prefix P", args: args{d: "P6"}, @@ -95,6 +101,84 @@ func TestParse(t *testing.T) { want: nil, errorMatchFn: newMatchFn(ErrIncompleteExpr), }, + { + name: "double P at start", + args: args{d: "PP1D"}, + want: nil, + errorMatchFn: newMatchFn(ErrUnexpectedInput), + }, + { + name: "trailing P in period section", + args: args{d: "P1DP"}, + want: nil, + errorMatchFn: newMatchFn(ErrUnexpectedInput), + }, + { + name: "double T", + args: args{d: "PTT1H"}, + want: nil, + errorMatchFn: newMatchFn(ErrUnexpectedInput), + }, + { + name: "T between time designators", + args: args{d: "PT1HT1M"}, + want: nil, + errorMatchFn: newMatchFn(ErrUnexpectedInput), + }, + { + name: "duplicate H designator", + args: args{d: "PT1H2H"}, + want: nil, + errorMatchFn: newMatchFn(ErrUnexpectedInput), + }, + { + name: "duplicate Y designator", + args: args{d: "P1Y2Y"}, + want: nil, + errorMatchFn: newMatchFn(ErrUnexpectedInput), + }, + { + name: "duplicate M designator in period", + args: args{d: "P1M2M"}, + want: nil, + errorMatchFn: newMatchFn(ErrUnexpectedInput), + }, + { + name: "duplicate M designator in time", + args: args{d: "PT1M2M"}, + want: nil, + errorMatchFn: newMatchFn(ErrUnexpectedInput), + }, + { + name: "time designators out of order M before H", + args: args{d: "PT1M1H"}, + want: nil, + errorMatchFn: newMatchFn(ErrUnexpectedInput), + }, + { + name: "period designators out of order D before Y", + args: args{d: "P1D1Y"}, + want: nil, + errorMatchFn: newMatchFn(ErrUnexpectedInput), + }, + { + name: "bare PT with no time components", + args: args{d: "PT"}, + want: nil, + errorMatchFn: newMatchFn(ErrIncompleteExpr), + }, + { + name: "T at end after valid time component", + args: args{d: "PT2HT"}, + want: nil, + errorMatchFn: newMatchFn(ErrUnexpectedInput), + }, + { + name: "non-ASCII digit", + args: args{d: "P٥Y"}, + want: nil, + errorMatchFn: newMatchFn(ErrUnexpectedInput), + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {