Skip to content

Commit 9af56a1

Browse files
committed
support scanning into all numeric types
Use reflection to support scanning into any numeric types, including named types that have an underlying numeric type.
1 parent d959cf1 commit 9af56a1

1 file changed

Lines changed: 69 additions & 74 deletions

File tree

datadriven.go

Lines changed: 69 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"io/ioutil"
2222
"os"
2323
"path/filepath"
24+
"reflect"
2425
"regexp"
2526
"runtime"
2627
"strconv"
@@ -743,41 +744,19 @@ func (arg CmdArg) scan(t testing.TB, pos string, dests ...interface{}) {
743744
}
744745

745746
func (arg CmdArg) scanAllErr(dest interface{}) error {
746-
// Try supported slice destination types.
747-
switch dest := dest.(type) {
748-
case *[]string:
749-
// Make a copy to avoid unexpected mutation of CmdArg.Vals.
750-
*dest = append([]string(nil), arg.Vals...)
751-
return nil
752-
case *[]int:
753-
*dest = make([]int, len(arg.Vals))
754-
for i := 0; i < len(arg.Vals); i++ {
755-
n, err := strconv.ParseInt(arg.Vals[i], 10, 64)
756-
if err != nil {
757-
return fmt.Errorf("arg %d: %w", i, err)
758-
}
759-
(*dest)[i] = int(n)
760-
}
761-
return nil
762-
case *[]uint64:
763-
*dest = make([]uint64, len(arg.Vals))
764-
for i := 0; i < len(arg.Vals); i++ {
765-
n, err := strconv.ParseUint(arg.Vals[i], 10, 64)
766-
if err != nil {
767-
return fmt.Errorf("arg %d: %w", i, err)
768-
}
769-
(*dest)[i] = uint64(n)
770-
}
771-
return nil
772-
case *[]float64:
773-
*dest = make([]float64, len(arg.Vals))
774-
for i := 0; i < len(arg.Vals); i++ {
775-
n, err := strconv.ParseFloat(arg.Vals[i], 64)
776-
if err != nil {
777-
return fmt.Errorf("arg %d: %w", i, err)
747+
rv := reflect.ValueOf(dest)
748+
if rv.Kind() != reflect.Ptr || rv.IsNil() {
749+
return fmt.Errorf("out must be a non-nil pointer to a slice")
750+
}
751+
752+
if sliceV := rv.Elem(); sliceV.Kind() == reflect.Slice {
753+
slice := reflect.MakeSlice(sliceV.Type(), len(arg.Vals), len(arg.Vals))
754+
for i := range arg.Vals {
755+
if err := arg.scanScalarErr(i, slice.Index(i).Addr().Interface()); err != nil {
756+
return err
778757
}
779-
(*dest)[i] = float64(n)
780758
}
759+
sliceV.Set(slice)
781760
return nil
782761
}
783762

@@ -795,55 +774,23 @@ func (arg CmdArg) scanScalarErr(i int, dest interface{}) error {
795774
return fmt.Errorf("cannot scan index %d of key %s", i, arg.Key)
796775
}
797776
val := arg.Vals[i]
777+
778+
// Special cases.
798779
switch dest := dest.(type) {
799-
case *string:
800-
*dest = val
801-
case *int:
802-
n, err := strconv.ParseInt(val, 10, 64)
803-
if err != nil {
804-
return err
805-
}
806-
*dest = int(n) // assume 64bit ints
807-
case *int64:
808-
n, err := strconv.ParseInt(val, 10, 64)
809-
if err != nil {
810-
return err
811-
}
812-
*dest = n
813-
case *uint64:
814-
n, err := strconv.ParseUint(val, 10, 64)
815-
if err != nil {
816-
return err
817-
}
818-
*dest = n
819-
case *uint32:
820-
n, err := strconv.ParseUint(val, 10, 32)
821-
if err != nil {
822-
return err
823-
}
824-
*dest = uint32(n)
825-
case *bool:
826-
b, err := strconv.ParseBool(val)
827-
if err != nil {
828-
return err
829-
}
830-
*dest = b
831-
case *float64:
832-
t, err := strconv.ParseFloat(val, 64)
833-
if err != nil {
834-
return err
835-
}
836-
*dest = t
837780
case *time.Duration:
838781
t, err := time.ParseDuration(val)
839782
if err != nil {
840783
return err
841784
}
842785
*dest = t
843-
default:
844-
return fmt.Errorf("unsupported type %T for destination #%d (might be easy to add it)", dest, i+1)
786+
return nil
845787
}
846-
return nil
788+
789+
rv := reflect.ValueOf(dest)
790+
if rv.Kind() != reflect.Ptr || rv.IsNil() {
791+
return fmt.Errorf("out must be a non-nil pointer")
792+
}
793+
return parseArgVal(val, rv.Elem())
847794
}
848795

849796
// Logf is a wrapper for tb.Logf which adds file position information, so
@@ -897,3 +844,51 @@ func indentLines(str string) string {
897844
}
898845
return b.String()
899846
}
847+
848+
// parseArgVal parses s and stores it into dest, which must have type int*,
849+
// uint*, float*, string, bool, or a named type with one of those underlying
850+
// kinds.
851+
//
852+
// Uses base 0 for ints/uints (supports 0x, 0o, 0b).
853+
func parseArgVal(s string, dest reflect.Value) error {
854+
t := dest.Type()
855+
856+
s = strings.TrimSpace(s)
857+
switch t.Kind() {
858+
case reflect.String:
859+
dest.SetString(s)
860+
861+
case reflect.Bool:
862+
b, err := strconv.ParseBool(s)
863+
if err != nil {
864+
return err
865+
}
866+
dest.SetBool(b)
867+
868+
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
869+
n, err := strconv.ParseInt(s, 0, int(t.Bits()))
870+
if err != nil {
871+
return fmt.Errorf("parse %q as %s: %w", s, t, err)
872+
}
873+
dest.SetInt(n)
874+
875+
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
876+
n, err := strconv.ParseUint(s, 0, int(t.Bits()))
877+
if err != nil {
878+
return fmt.Errorf("parse %q as %s: %w", s, t, err)
879+
}
880+
dest.SetUint(n)
881+
882+
case reflect.Float32, reflect.Float64:
883+
f, err := strconv.ParseFloat(s, int(t.Bits()))
884+
if err != nil {
885+
return fmt.Errorf("parse %q as %s: %w", s, t, err)
886+
}
887+
dest.SetFloat(f)
888+
889+
default:
890+
return fmt.Errorf("dest must point to a numeric type; got %s", t)
891+
}
892+
893+
return nil
894+
}

0 commit comments

Comments
 (0)