common_test.go
package cmd
import (
"flag"
"io"
"strings"
"testing"
)
// silentFlagSet builds a FlagSet named "test" for use in tests. Parse errors
// are returned to the caller (ContinueOnError) instead of exiting. Nothing is
// printed: diagnostics go to io.Discard and the usage callback is a no-op.
func silentFlagSet() *flag.FlagSet {
	set := flag.NewFlagSet("test", flag.ContinueOnError)
	set.Usage = func() {}
	set.SetOutput(io.Discard)
	return set
}
// TestRequireFlags runs requireFlags over a table of flag maps.
//
// Per the table, an empty string and a zero int count as missing, and
// negative ints are accepted. For each failing case the error must name every
// flag listed in wantMissing. It must also NOT name any flag from required
// that was satisfied; without that check, an implementation that reported
// every flag would still pass.
func TestRequireFlags(t *testing.T) {
	tests := []struct {
		name        string
		required    map[string]any
		wantErr     bool
		wantMissing []string
	}{
		{"all string present", map[string]any{"--db": "x.duckdb", "--id": "abc"}, false, nil},
		{"empty string", map[string]any{"--db": "", "--id": "abc"}, true, []string{"--db"}},
		{"multiple missing strings", map[string]any{"--db": "", "--id": ""}, true, []string{"--db", "--id"}},
		{"no flags", map[string]any{}, false, nil},
		{"all int non-zero", map[string]any{"--n": 5, "--m": 1}, false, nil},
		{"zero int", map[string]any{"--n": 0, "--m": 1}, true, []string{"--n"}},
		{"multiple zero ints", map[string]any{"--n": 0, "--m": 0}, true, []string{"--m", "--n"}},
		{"negative int is allowed", map[string]any{"--n": -1}, false, nil},
		{"mixed types", map[string]any{"--db": "x", "--n": 0}, true, []string{"--n"}},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := requireFlags(silentFlagSet(), tt.required)
			if (err != nil) != tt.wantErr {
				t.Fatalf("err=%v, wantErr=%v", err, tt.wantErr)
			}
			if !tt.wantErr {
				return
			}
			msg := err.Error()

			missing := make(map[string]bool, len(tt.wantMissing))
			for _, name := range tt.wantMissing {
				missing[name] = true
				if !strings.Contains(msg, name) {
					t.Errorf("err %q missing flag name %q", msg, name)
				}
			}

			// Satisfied flags must not be reported as missing. The check uses
			// substring matching, so flag names in this table must not be
			// substrings of one another.
			for name := range tt.required {
				if !missing[name] && strings.Contains(msg, name) {
					t.Errorf("err %q names satisfied flag %q", msg, name)
				}
			}
		})
	}
}
// TestRequireFlagsUnsupportedType passes requireFlags a float64 value (1.5).
// It expects a non-nil error whose message contains "unsupported type".
func TestRequireFlagsUnsupportedType(t *testing.T) {
	flags := map[string]any{"--bad": 1.5}
	got := requireFlags(silentFlagSet(), flags)

	rejected := got != nil && strings.Contains(got.Error(), "unsupported type")
	if !rejected {
		t.Fatalf("expected unsupported type error, got %v", got)
	}
}