Feature: Add xand tag (#442)

* Feat: Add xand group and check for missing

* Fix: Split and combine err in TestMultiand for consistency

* Feat: Check missing required flags in xand groups

* Feat: Handle combined xor and xand

* Docs: Add info about combined xand and required use

* Docs: Fix language error in xand description

Co-authored-by: Stautis <thkrst@gmail.com>

* Feat: Rename xand to and

* Refactor: Switch from fmt.Sprintf to err.Error

* Refactor: Get requiredAndGroup map in separate function

---------

Co-authored-by: Stautis <thkrst@gmail.com>
This commit is contained in:
Camilla
2024-08-08 08:58:22 +02:00
committed by GitHub
parent 5f9c5cc822
commit ff6d5ba7d5
7 changed files with 195 additions and 14 deletions
+77 -2
View File
@@ -259,7 +259,7 @@ func (c *Context) Validate() error { //nolint: gocyclo
if err := checkMissingPositionals(positionals, node.Positional); err != nil {
return err
}
if err := checkXorDuplicates(c.Path); err != nil {
if err := checkXorDuplicatedAndAndMissing(c.Path); err != nil {
return err
}
@@ -831,23 +831,35 @@ func (c *Context) PrintUsage(summary bool) error {
func checkMissingFlags(flags []*Flag) error {
xorGroupSet := map[string]bool{}
xorGroup := map[string][]string{}
andGroupSet := map[string]bool{}
andGroup := map[string][]string{}
missing := []string{}
andGroupRequired := getRequiredAndGroupMap(flags)
for _, flag := range flags {
for _, and := range flag.And {
flag.Required = andGroupRequired[and]
}
if flag.Set {
for _, xor := range flag.Xor {
xorGroupSet[xor] = true
}
for _, and := range flag.And {
andGroupSet[and] = true
}
}
if !flag.Required || flag.Set {
continue
}
if len(flag.Xor) > 0 {
if len(flag.Xor) > 0 || len(flag.And) > 0 {
for _, xor := range flag.Xor {
if xorGroupSet[xor] {
continue
}
xorGroup[xor] = append(xorGroup[xor], flag.Summary())
}
for _, and := range flag.And {
andGroup[and] = append(andGroup[and], flag.Summary())
}
} else {
missing = append(missing, flag.Summary())
}
@@ -857,6 +869,11 @@ func checkMissingFlags(flags []*Flag) error {
missing = append(missing, strings.Join(flags, " or "))
}
}
for _, flags := range andGroup {
if len(flags) > 1 {
missing = append(missing, strings.Join(flags, " and "))
}
}
if len(missing) == 0 {
return nil
@@ -867,6 +884,18 @@ func checkMissingFlags(flags []*Flag) error {
return fmt.Errorf("missing flags: %s", strings.Join(missing, ", "))
}
func getRequiredAndGroupMap(flags []*Flag) map[string]bool {
andGroupRequired := map[string]bool{}
for _, flag := range flags {
for _, and := range flag.And {
if flag.Required {
andGroupRequired[and] = true
}
}
}
return andGroupRequired
}
func checkMissingChildren(node *Node) error {
missing := []string{}
@@ -977,6 +1006,20 @@ func checkPassthroughArg(target reflect.Value) bool {
}
}
func checkXorDuplicatedAndAndMissing(paths []*Path) error {
errs := []string{}
if err := checkXorDuplicates(paths); err != nil {
errs = append(errs, err.Error())
}
if err := checkAndMissing(paths); err != nil {
errs = append(errs, err.Error())
}
if len(errs) > 0 {
return fmt.Errorf(strings.Join(errs, ", "))
}
return nil
}
func checkXorDuplicates(paths []*Path) error {
for _, path := range paths {
seen := map[string]*Flag{}
@@ -995,6 +1038,38 @@ func checkXorDuplicates(paths []*Path) error {
return nil
}
func checkAndMissing(paths []*Path) error {
for _, path := range paths {
missingMsgs := []string{}
andGroups := map[string][]*Flag{}
for _, flag := range path.Flags {
for _, and := range flag.And {
andGroups[and] = append(andGroups[and], flag)
}
}
for _, flags := range andGroups {
oneSet := false
notSet := []*Flag{}
flagNames := []string{}
for _, flag := range flags {
flagNames = append(flagNames, flag.Name)
if flag.Set {
oneSet = true
} else {
notSet = append(notSet, flag)
}
}
if len(notSet) > 0 && oneSet {
missingMsgs = append(missingMsgs, fmt.Sprintf("--%s must be used together", strings.Join(flagNames, " and --")))
}
}
if len(missingMsgs) > 0 {
return fmt.Errorf("%s", strings.Join(missingMsgs, ", "))
}
}
return nil
}
func findPotentialCandidates(needle string, haystack []string, format string, args ...interface{}) error {
if len(haystack) == 0 {
return fmt.Errorf(format, args...)