Use interface{} instead of string in tokens.
This allows the scanner and resolvers to pass Go types around rather than having to serialise/deserialise to/from strings.
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
package kong
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"math/bits"
|
||||
"net/url"
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -194,9 +196,8 @@ func (r *Registry) RegisterDefaults() *Registry {
|
||||
RegisterKind(reflect.Float32, floatDecoder(32)).
|
||||
RegisterKind(reflect.Float64, floatDecoder(64)).
|
||||
RegisterKind(reflect.String, MapperFunc(func(ctx *DecodeContext, target reflect.Value) error {
|
||||
token := ctx.Scan.PopValue("string")
|
||||
target.SetString(token)
|
||||
return nil
|
||||
_, err := ctx.Scan.PopValueInto("string", target.Addr().Interface())
|
||||
return err
|
||||
})).
|
||||
RegisterKind(reflect.Bool, boolMapper{}).
|
||||
RegisterKind(reflect.Slice, sliceDecoder(r)).
|
||||
@@ -214,7 +215,16 @@ type boolMapper struct{}
|
||||
func (boolMapper) Decode(ctx *DecodeContext, target reflect.Value) error {
|
||||
if ctx.Scan.Peek().Type == FlagValueToken {
|
||||
token := ctx.Scan.Pop()
|
||||
target.SetBool(token.Value == "true")
|
||||
switch v := token.Value.(type) {
|
||||
case string:
|
||||
target.SetBool(strings.ToLower(v) == "true")
|
||||
|
||||
case bool:
|
||||
target.SetBool(v)
|
||||
|
||||
default:
|
||||
return errors.Errorf("expected bool but got %q (%T)", token.Value, token.Value)
|
||||
}
|
||||
} else {
|
||||
target.SetBool(true)
|
||||
}
|
||||
@@ -224,10 +234,13 @@ func (boolMapper) IsBool() bool { return true }
|
||||
|
||||
func durationDecoder() MapperFunc {
|
||||
return func(ctx *DecodeContext, target reflect.Value) error {
|
||||
value := ctx.Scan.PopValue("duration")
|
||||
var value string
|
||||
if _, err := ctx.Scan.PopValueInto("duration", &value); err != nil {
|
||||
return err
|
||||
}
|
||||
r, err := time.ParseDuration(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("expected duration but got %q: %s", value, err)
|
||||
return errors.Errorf("expected duration but got %q: %s", value, err)
|
||||
}
|
||||
target.Set(reflect.ValueOf(r))
|
||||
return nil
|
||||
@@ -240,7 +253,10 @@ func timeDecoder() MapperFunc {
|
||||
if ctx.Value.Format != "" {
|
||||
format = ctx.Value.Format
|
||||
}
|
||||
value := ctx.Scan.PopValue("time")
|
||||
var value string
|
||||
if _, err := ctx.Scan.PopValueInto("time", &value); err != nil {
|
||||
return err
|
||||
}
|
||||
t, err := time.Parse(format, value)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -250,38 +266,86 @@ func timeDecoder() MapperFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func intDecoder(bits int) MapperFunc {
|
||||
func intDecoder(bits int) MapperFunc { // nolint: dupl
|
||||
return func(ctx *DecodeContext, target reflect.Value) error {
|
||||
value := ctx.Scan.PopValue("int")
|
||||
n, err := strconv.ParseInt(value, 10, bits)
|
||||
t, err := ctx.Scan.PopValue("int")
|
||||
if err != nil {
|
||||
return fmt.Errorf("expected int but got %q", value)
|
||||
return err
|
||||
}
|
||||
switch v := t.Value.(type) {
|
||||
case string:
|
||||
n, err := strconv.ParseInt(v, 10, bits)
|
||||
if err != nil {
|
||||
return errors.Errorf("expected an int but got %q (%T)", t, t.Value)
|
||||
}
|
||||
target.SetInt(n)
|
||||
|
||||
case float64:
|
||||
target.SetInt(int64(v))
|
||||
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
target.Set(reflect.ValueOf(v))
|
||||
|
||||
default:
|
||||
return errors.Errorf("expected an int but got %q (%T)", t, t.Value)
|
||||
}
|
||||
target.SetInt(n)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func uintDecoder(bits int) MapperFunc {
|
||||
func uintDecoder(bits int) MapperFunc { // nolint: dupl
|
||||
return func(ctx *DecodeContext, target reflect.Value) error {
|
||||
value := ctx.Scan.PopValue("uint")
|
||||
n, err := strconv.ParseUint(value, 10, bits)
|
||||
t, err := ctx.Scan.PopValue("uint")
|
||||
if err != nil {
|
||||
return fmt.Errorf("expected unsigned int but got %q", value)
|
||||
return err
|
||||
}
|
||||
switch v := t.Value.(type) {
|
||||
case string:
|
||||
n, err := strconv.ParseUint(v, 10, bits)
|
||||
if err != nil {
|
||||
return errors.Errorf("expected a uint but got %q (%T)", t, t.Value)
|
||||
}
|
||||
target.SetUint(n)
|
||||
|
||||
case float64:
|
||||
target.SetUint(uint64(v))
|
||||
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
target.Set(reflect.ValueOf(v))
|
||||
|
||||
default:
|
||||
return errors.Errorf("expected an int but got %q (%T)", t, t.Value)
|
||||
}
|
||||
target.SetUint(n)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func floatDecoder(bits int) MapperFunc {
|
||||
return func(ctx *DecodeContext, target reflect.Value) error {
|
||||
value := ctx.Scan.PopValue("float")
|
||||
n, err := strconv.ParseFloat(value, bits)
|
||||
t, err := ctx.Scan.PopValue("float")
|
||||
if err != nil {
|
||||
return fmt.Errorf("expected float but got %q", value)
|
||||
return err
|
||||
}
|
||||
switch v := t.Value.(type) {
|
||||
case string:
|
||||
n, err := strconv.ParseFloat(v, bits)
|
||||
if err != nil {
|
||||
return errors.Errorf("expected a float but got %q (%T)", t, t.Value)
|
||||
}
|
||||
target.SetFloat(n)
|
||||
|
||||
case float32:
|
||||
target.SetFloat(float64(v))
|
||||
|
||||
case float64:
|
||||
target.SetFloat(v)
|
||||
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
target.Set(reflect.ValueOf(v))
|
||||
|
||||
default:
|
||||
return errors.Errorf("expected an int but got %q (%T)", t, t.Value)
|
||||
}
|
||||
target.SetFloat(n)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -294,17 +358,43 @@ func mapDecoder(r *Registry) MapperFunc {
|
||||
el := target.Type()
|
||||
var childScanner *Scanner
|
||||
if ctx.Value.Flag != nil {
|
||||
t := ctx.Scan.Pop()
|
||||
// If decoding a flag, we need an argument.
|
||||
childScanner = Scan(SplitEscaped(ctx.Scan.PopValue("map"), ';')...)
|
||||
if t.IsEOL() {
|
||||
return errors.Errorf("unexpected EOL")
|
||||
}
|
||||
switch v := t.Value.(type) {
|
||||
case string:
|
||||
childScanner = Scan(SplitEscaped(v, ';')...)
|
||||
|
||||
case []map[string]interface{}:
|
||||
for _, m := range v {
|
||||
err := jsonTranscode(m, target.Addr().Interface())
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
case map[string]interface{}:
|
||||
return jsonTranscode(v, target.Addr().Interface())
|
||||
|
||||
default:
|
||||
return errors.Errorf("invalid map value %q (of type %T)", t, t.Value)
|
||||
}
|
||||
} else {
|
||||
tokens := ctx.Scan.PopWhile(func(t Token) bool { return t.IsValue() })
|
||||
childScanner = ScanFromTokens(tokens...)
|
||||
}
|
||||
for !childScanner.Peek().IsEOL() {
|
||||
token := childScanner.PopValue("map")
|
||||
var token string
|
||||
_, err := childScanner.PopValueInto("map", &token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
parts := strings.SplitN(token, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
return fmt.Errorf("expected \"<key>=<value>\" but got %q", token)
|
||||
return errors.Errorf("expected \"<key>=<value>\" but got %q", token)
|
||||
}
|
||||
key, value := parts[0], parts[1]
|
||||
|
||||
@@ -312,7 +402,7 @@ func mapDecoder(r *Registry) MapperFunc {
|
||||
if typ := ctx.Value.Tag.Type; typ != "" {
|
||||
parts := strings.Split(typ, ":")
|
||||
if len(parts) != 2 {
|
||||
return fmt.Errorf("type:\"\" on map field must be in the form \"[<keytype>]:[<valuetype>]\"")
|
||||
return errors.Errorf("type:\"\" on map field must be in the form \"[<keytype>]:[<valuetype>]\"")
|
||||
}
|
||||
keyTypeName, valueTypeName = parts[0], parts[1]
|
||||
}
|
||||
@@ -321,14 +411,14 @@ func mapDecoder(r *Registry) MapperFunc {
|
||||
keyDecoder := r.ForNamedType(keyTypeName, el.Key())
|
||||
keyValue := reflect.New(el.Key()).Elem()
|
||||
if err := keyDecoder.Decode(ctx.WithScanner(keyScanner), keyValue); err != nil {
|
||||
return fmt.Errorf("invalid map key %q", key)
|
||||
return errors.Errorf("invalid map key %q", key)
|
||||
}
|
||||
|
||||
valueScanner := Scan(value)
|
||||
valueDecoder := r.ForNamedType(valueTypeName, el.Elem())
|
||||
valueValue := reflect.New(el.Elem()).Elem()
|
||||
if err := valueDecoder.Decode(ctx.WithScanner(valueScanner), valueValue); err != nil {
|
||||
return fmt.Errorf("invalid map value %q", value)
|
||||
return errors.Errorf("invalid map value %q", value)
|
||||
}
|
||||
|
||||
target.SetMapIndex(keyValue, valueValue)
|
||||
@@ -343,21 +433,35 @@ func sliceDecoder(r *Registry) MapperFunc {
|
||||
sep := ctx.Value.Tag.Sep
|
||||
var childScanner *Scanner
|
||||
if ctx.Value.Flag != nil {
|
||||
t := ctx.Scan.Pop()
|
||||
// If decoding a flag, we need an argument.
|
||||
childScanner = Scan(SplitEscaped(ctx.Scan.PopValue("list"), sep)...)
|
||||
if t.IsEOL() {
|
||||
return errors.Errorf("unexpected EOL")
|
||||
}
|
||||
switch v := t.Value.(type) {
|
||||
case string:
|
||||
childScanner = Scan(SplitEscaped(v, sep)...)
|
||||
|
||||
case []interface{}:
|
||||
return jsonTranscode(v, target.Addr().Interface())
|
||||
|
||||
default:
|
||||
v = []interface{}{v}
|
||||
return jsonTranscode(v, target.Addr().Interface())
|
||||
}
|
||||
} else {
|
||||
tokens := ctx.Scan.PopWhile(func(t Token) bool { return t.IsValue() })
|
||||
childScanner = ScanFromTokens(tokens...)
|
||||
}
|
||||
childDecoder := r.ForNamedType(ctx.Value.Tag.Type, el)
|
||||
if childDecoder == nil {
|
||||
return fmt.Errorf("no mapper for element type of %s", target.Type())
|
||||
return errors.Errorf("no mapper for element type of %s", target.Type())
|
||||
}
|
||||
for !childScanner.Peek().IsEOL() {
|
||||
childValue := reflect.New(el).Elem()
|
||||
err := childDecoder.Decode(ctx.WithScanner(childScanner), childValue)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
target.Set(reflect.Append(target, childValue))
|
||||
}
|
||||
@@ -371,9 +475,13 @@ func pathMapper(r *Registry) MapperFunc {
|
||||
return sliceDecoder(r)(ctx, target)
|
||||
}
|
||||
if target.Kind() != reflect.String {
|
||||
return fmt.Errorf("\"path\" type must be applied to a string not %s", target.Type())
|
||||
return errors.Errorf("\"path\" type must be applied to a string not %s", target.Type())
|
||||
}
|
||||
var path string
|
||||
_, err := ctx.Scan.PopValueInto("file", &path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
path := ctx.Scan.PopValue("file")
|
||||
path = ExpandPath(path)
|
||||
target.SetString(path)
|
||||
return nil
|
||||
@@ -386,16 +494,20 @@ func existingFileMapper(r *Registry) MapperFunc {
|
||||
return sliceDecoder(r)(ctx, target)
|
||||
}
|
||||
if target.Kind() != reflect.String {
|
||||
return fmt.Errorf("\"existingfile\" type must be applied to a string not %s", target.Type())
|
||||
return errors.Errorf("\"existingfile\" type must be applied to a string not %s", target.Type())
|
||||
}
|
||||
var path string
|
||||
_, err := ctx.Scan.PopValueInto("file", &path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
path := ctx.Scan.PopValue("file")
|
||||
path = ExpandPath(path)
|
||||
stat, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if stat.IsDir() {
|
||||
return fmt.Errorf("%q exists but is a directory", path)
|
||||
return errors.Errorf("%q exists but is a directory", path)
|
||||
}
|
||||
target.SetString(path)
|
||||
return nil
|
||||
@@ -408,16 +520,20 @@ func existingDirMapper(r *Registry) MapperFunc {
|
||||
return sliceDecoder(r)(ctx, target)
|
||||
}
|
||||
if target.Kind() != reflect.String {
|
||||
return fmt.Errorf("\"existingdir\" must be applied to a string not %s", target.Type())
|
||||
return errors.Errorf("\"existingdir\" must be applied to a string not %s", target.Type())
|
||||
}
|
||||
var path string
|
||||
_, err := ctx.Scan.PopValueInto("file", &path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
path := ctx.Scan.PopValue("file")
|
||||
path = ExpandPath(path)
|
||||
stat, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !stat.IsDir() {
|
||||
return fmt.Errorf("%q exists but is not a directory", path)
|
||||
return errors.Errorf("%q exists but is not a directory", path)
|
||||
}
|
||||
target.SetString(path)
|
||||
return nil
|
||||
@@ -426,10 +542,15 @@ func existingDirMapper(r *Registry) MapperFunc {
|
||||
|
||||
func urlMapper() MapperFunc {
|
||||
return func(ctx *DecodeContext, target reflect.Value) error {
|
||||
url, err := url.Parse(ctx.Scan.PopValue("url"))
|
||||
var urlStr string
|
||||
_, err := ctx.Scan.PopValueInto("url", &urlStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
url, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
target.Set(reflect.ValueOf(url))
|
||||
return nil
|
||||
}
|
||||
@@ -479,11 +600,24 @@ func JoinEscaped(s []string, sep rune) string {
|
||||
type FileContentFlag []byte
|
||||
|
||||
func (f *FileContentFlag) Decode(ctx *DecodeContext) error { // nolint: golint
|
||||
filename := ExpandPath(ctx.Scan.PopValue("filename"))
|
||||
var filename string
|
||||
_, err := ctx.Scan.PopValueInto("filename", &filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
filename = ExpandPath(filename)
|
||||
data, err := ioutil.ReadFile(filename) // nolint: gosec
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open %q: %s", filename, err)
|
||||
return errors.Errorf("failed to open %q: %s", filename, err)
|
||||
}
|
||||
*f = data
|
||||
return nil
|
||||
}
|
||||
|
||||
func jsonTranscode(in, out interface{}) error {
|
||||
data, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
return errors.Wrapf(json.Unmarshal(data, out), "%#v -> %T", in, out)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user