package cmd

import (
	"flag"
	"fmt"
	"os"
	"strconv"

	"skraak/db"
)

// initEventLog enables transaction event logging for the database at dbPath.
// The event log path is configured as dbPath + ".events.jsonl".
//
// It returns a cleanup function that closes the event log; callers should
// defer it. A close failure is reported on stderr as a warning and is not
// treated as fatal.
func initEventLog(dbPath string) func() {
	cfg := db.EventLogConfig{
		Path:    dbPath + ".events.jsonl",
		Enabled: true,
	}
	db.SetEventLogConfig(cfg)

	cleanup := func() {
		err := db.CloseEventLog()
		if err == nil {
			return
		}
		fmt.Fprintf(os.Stderr, "Warning: failed to close event log: %v\n", err)
	}
	return cleanup
}

// checkFlags verifies that required string flags were supplied.
//
// pairs is a flat list of (flagName, flagValue) pairs, for example:
//
//	checkFlags(fs, "--db", *dbPath, "--id", *id)
//
// Every flag whose value is the empty string is reported as missing. A
// trailing flag name with no paired value (an odd-length pairs list) is also
// reported as missing, rather than panicking with an index-out-of-range error.
//
// If any flag is missing, fs.Usage is invoked and an error listing the missing
// flag names (in argument order) is returned. It never calls os.Exit.
func checkFlags(fs *flag.FlagSet, pairs ...string) error {
	var missing []string
	for i := 0; i < len(pairs); i += 2 {
		// With an odd-length list the last name has no value at i+1; treat it
		// as missing instead of indexing past the end of the slice.
		if i+1 >= len(pairs) || pairs[i+1] == "" {
			missing = append(missing, pairs[i])
		}
	}
	if len(missing) > 0 {
		fs.Usage()
		return fmt.Errorf("missing required flags: %v", missing)
	}
	return nil
}

// checkNonZeroFlags validates required integer flags.
//
// Each element of pairs names a flag (Name) and holds the value it was parsed
// to (Value). A zero Value is treated as "not supplied". When one or more flags
// are zero, fs.Usage is invoked and an error naming all of them (in argument
// order) is returned. Otherwise it returns nil. It never calls os.Exit.
func checkNonZeroFlags(fs *flag.FlagSet, pairs ...struct {
	Name  string
	Value int
}) error {
	var absent []string
	for idx := range pairs {
		if pairs[idx].Value != 0 {
			continue
		}
		absent = append(absent, pairs[idx].Name)
	}
	if len(absent) == 0 {
		return nil
	}
	fs.Usage()
	return fmt.Errorf("missing required flags: %v", absent)
}

// mustValue returns args[*i+1], the argument that follows the flag at position
// *i, and advances *i by 2 so that it points past both the flag and its value.
//
// If no argument follows, it prints an error naming the flag to stderr and
// terminates the process with exit status 1, because a missing flag value is an
// unrecoverable CLI error. An empty-string argument still counts as a value.
func mustValue(args []string, i *int, name string) string {
	next := *i + 1
	if next >= len(args) {
		fmt.Fprintf(os.Stderr, "Error: %s requires a value\n", name)
		os.Exit(1)
	}
	*i = next + 1
	return args[next]
}

// mustIntValue reads the argument following the flag at position *i (advancing
// *i by 2 via mustValue). It parses that argument with strconv.Atoi and checks
// that the result lies in the inclusive range [lo, hi].
//
// A missing value, a parse failure, or an out-of-range result prints an error
// to stderr and exits the process with status 1.
func mustIntValue(args []string, i *int, name string, lo, hi int) int {
	n, err := strconv.Atoi(mustValue(args, i, name))
	switch {
	case err != nil:
		fmt.Fprintf(os.Stderr, "Error: %s must be an integer\n", name)
		os.Exit(1)
	case n < lo || n > hi:
		fmt.Fprintf(os.Stderr, "Error: %s must be between %d and %d\n", name, lo, hi)
		os.Exit(1)
	}
	return n
}