package cmd

import (
	"bufio"
	"context"
	"database/sql"
	"encoding/json"
	"flag"
	"fmt"
	"os"
	"strings"

	"skraak/db"
)

// RunReplay handles the "replay" subcommand. args holds the arguments that follow
// "replay": args[0] selects the subcommand and the remainder is forwarded to it.
// A missing or unrecognised subcommand prints usage to stderr and exits with status 1.
func RunReplay(args []string) {
	if len(args) == 0 {
		printReplayUsage()
		os.Exit(1)
	}

	// Dispatch table of known subcommands; add new ones here.
	subcommands := map[string]func([]string){
		"events": runReplayEvents,
	}

	name := args[0]
	if run, ok := subcommands[name]; ok {
		run(args[1:])
		return
	}

	fmt.Fprintf(os.Stderr, "Unknown replay subcommand: %s\n\n", name)
	printReplayUsage()
	os.Exit(1)
}

// printReplayUsage writes the help text for "skraak replay" (synopsis, list of
// subcommands, and example invocations) to stderr.
func printReplayUsage() {
	helpLines := []string{
		"Usage: skraak replay <subcommand> [options]\n\n",
		"Subcommands:\n",
		"  events    Replay event log into database\n",
		"\nExamples:\n",
		"  skraak replay events --db ./backup.duckdb --log ./skraak.duckdb.events.jsonl\n",
		"  skraak replay events --db ./backup.duckdb --log ./events.jsonl --dry-run\n",
		"  skraak replay events --db ./backup.duckdb --log ./events.jsonl --last 10\n",
	}
	for _, text := range helpLines {
		fmt.Fprint(os.Stderr, text)
	}
}

func runReplayEvents(args []string) {
	fs := flag.NewFlagSet("replay events", flag.ExitOnError)
	dbPath := fs.String("db", "", "Path to target database (required)")
	logPath := fs.String("log", "", "Path to event log file (required)")
	dryRun := fs.Bool("dry-run", false, "Print events without executing")
	fromID := fs.String("from", "", "Start from event ID (inclusive)")
	toID := fs.String("to", "", "Stop at event ID (inclusive)")
	lastN := fs.Int("last", 0, "Replay last N events (0 = all)")
	continueOnError := fs.Bool("continue", false, "Continue past errors")

	fs.Usage = func() {
		fmt.Fprintf(os.Stderr, "Usage: skraak replay events [options]\n\n")
		fmt.Fprintf(os.Stderr, "Replay event log into database.\n\n")
		fmt.Fprintf(os.Stderr, "Options:\n")
		fs.PrintDefaults()
		fmt.Fprintf(os.Stderr, "\nExamples:\n")
		fmt.Fprintf(os.Stderr, "  skraak replay events --db ./backup.duckdb --log ./events.jsonl\n")
		fmt.Fprintf(os.Stderr, "  skraak replay events --db ./backup.duckdb --log ./events.jsonl --dry-run\n")
		fmt.Fprintf(os.Stderr, "  skraak replay events --db ./backup.duckdb --log ./events.jsonl --last 10\n")
	}

	if err := fs.Parse(args); err != nil {
		os.Exit(1)
	}

	// Validate required flags
	missing := []string{}
	if *dbPath == "" {
		missing = append(missing, "--db")
	}
	if *logPath == "" {
		missing = append(missing, "--log")
	}
	if len(missing) > 0 {
		fmt.Fprintf(os.Stderr, "Error: missing required flags: %v\n\n", missing)
		fs.Usage()
		os.Exit(1)
	}

	// Read events
	events, err := readEvents(*logPath)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error reading events: %v\n", err)
		os.Exit(1)
	}

	// Filter events
	events = filterEvents(events, *fromID, *toID, *lastN)

	fmt.Fprintf(os.Stderr, "Found %d events to replay\n", len(events))

	if *dryRun {
		for i, event := range events {
			fmt.Printf("\n[%d/%d] Event %s (%s)\n", i+1, len(events), event.ID, event.Tool)
			for _, q := range event.Queries {
				fmt.Printf("  SQL: %s\n", truncateSQL(q.SQL, 80))
				fmt.Printf("  Params: %v\n", q.Parameters)
			}
		}
		return
	}

	// Open database
	database, err := db.OpenWriteableDB(*dbPath)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error opening database: %v\n", err)
		os.Exit(1)
	}
	defer database.Close()

	// Disable event logging for replay
	db.SetEventLogConfig(db.EventLogConfig{Enabled: false})

	// Replay each event
	successCount := 0
	failCount := 0

	for i, event := range events {
		fmt.Fprintf(os.Stderr, "\n[%d/%d] Replaying event %s (%s)...\n", i+1, len(events), event.ID, event.Tool)

		err := replayEvent(database, event)
		if err != nil {
			failCount++
			fmt.Fprintf(os.Stderr, "  ERROR: %v\n", err)
			if !*continueOnError {
				fmt.Fprintf(os.Stderr, "Stopping due to error. Use --continue to skip errors.\n")
				os.Exit(1)
			}
		} else {
			successCount++
			fmt.Fprintf(os.Stderr, "  OK (%d queries)\n", len(event.Queries))
		}
	}

	fmt.Fprintf(os.Stderr, "\nReplay complete: %d succeeded, %d failed\n", successCount, failCount)
}

// TransactionEvent represents a transaction event from the log. Each non-blank line
// of the JSONL event log read by readEvents decodes into one TransactionEvent.
type TransactionEvent struct {
	// ID identifies the event; the --from and --to filters compare against it exactly.
	ID        string        `json:"id"`
	// Timestamp is decoded as a raw string and is not parsed here.
	// NOTE(review): the format is not visible in this file; confirm against the writer.
	Timestamp string        `json:"timestamp"`
	// Tool is a label shown in the progress and dry-run output; it may be absent.
	Tool      string        `json:"tool,omitempty"`
	// Queries are executed in order, inside a single transaction, by replayEvent.
	Queries   []QueryRecord `json:"queries"`
	// Success gates replay: filterEvents drops events where this is false.
	Success   bool          `json:"success"`
	// Duration is read from "duration_ms"; presumably the original run time in
	// milliseconds (per the tag name). Replay does not use it.
	Duration  int64         `json:"duration_ms"`
}

// QueryRecord represents a single SQL statement with parameters, as recorded in an
// event's "queries" array.
type QueryRecord struct {
	// SQL is the statement text, passed unchanged to ExecContext during replay.
	SQL        string `json:"sql"`
	// Parameters are the positional arguments, spread in order into ExecContext.
	// Element types are whatever readEvents produces when decoding the JSON values.
	Parameters []any  `json:"parameters"`
}

// readEvents reads all events from the JSONL file at path, one JSON object per line,
// and returns them in file order.
//
// Blank and whitespace-only lines are skipped. A line that fails to parse is skipped
// with a warning on stderr naming its 1-based line number, so one corrupt entry does
// not block replay of the rest of the log. An error is returned only if the file
// cannot be opened or read; this includes a line longer than 20MB.
//
// Numeric query parameters keep their integer-ness: JSON numbers that are integers
// fitting in int64 become int64, and all other numbers become float64. The default
// decoder would turn every number into float64. That loses precision above 2^53
// and changes how an integer value is bound (e.g. 5 into a text column as "5.0").
func readEvents(path string) ([]TransactionEvent, error) {
	file, err := os.Open(path)
	if err != nil {
		return nil, fmt.Errorf("failed to open event log: %w", err)
	}
	defer file.Close()

	const maxLineSize = 20 * 1024 * 1024 // 20MB max line size

	var events []TransactionEvent
	scanner := bufio.NewScanner(file)
	// Start with a small buffer and let the scanner grow it on demand up to
	// maxLineSize, instead of allocating the full 20MB before reading anything.
	scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
	lineNum := 0

	for scanner.Scan() {
		lineNum++
		line := strings.TrimSpace(scanner.Text())
		if line == "" {
			continue
		}

		event, err := decodeEvent(line)
		if err != nil {
			fmt.Fprintf(os.Stderr, "Warning: failed to parse line %d: %v\n", lineNum, err)
			continue
		}

		events = append(events, event)
	}

	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("error reading event log: %w", err)
	}

	return events, nil
}

// decodeEvent parses one JSONL line into a TransactionEvent. It decodes numbers as
// json.Number and then narrows query parameters with normalizeJSONNumber.
func decodeEvent(line string) (TransactionEvent, error) {
	var event TransactionEvent
	dec := json.NewDecoder(strings.NewReader(line))
	dec.UseNumber() // only affects values decoded into interface types (the parameters)
	if err := dec.Decode(&event); err != nil {
		return TransactionEvent{}, err
	}
	// Unlike json.Unmarshal, Decoder.Decode accepts trailing data after the value;
	// reject it explicitly so malformed lines are still reported.
	if rest := strings.TrimSpace(line[dec.InputOffset():]); rest != "" {
		return TransactionEvent{}, fmt.Errorf("unexpected data after JSON value")
	}
	for i := range event.Queries {
		params := event.Queries[i].Parameters
		for j := range params {
			params[j] = normalizeJSONNumber(params[j])
		}
	}
	return event, nil
}

// normalizeJSONNumber replaces json.Number values, recursively inside []any and
// map[string]any, with int64 when the literal is an integer that fits, else float64.
// A number representable as neither is left as its string literal.
// All other values are returned unchanged.
func normalizeJSONNumber(v any) any {
	switch t := v.(type) {
	case json.Number:
		if i, err := t.Int64(); err == nil {
			return i
		}
		if f, err := t.Float64(); err == nil {
			return f
		}
		return t.String()
	case []any:
		for i := range t {
			t[i] = normalizeJSONNumber(t[i])
		}
		return t
	case map[string]any:
		for k, e := range t {
			t[k] = normalizeJSONNumber(e)
		}
		return t
	default:
		return v
	}
}

// filterEvents selects the events to replay, preserving log order. The filters apply
// in this sequence:
//
//   - fromID (if non-empty): drop everything before the first event with that ID.
//   - toID (if non-empty): drop everything after the first event with that ID at or
//     after the start position.
//   - lastN (if > 0): keep only the final lastN of what remains.
//   - Finally, drop events whose Success flag is false. Because this runs after
//     lastN, "--last N" can yield fewer than N events.
//
// If fromID or toID is given but no matching event is found, a warning is written to
// stderr and no events are returned. Falling back to the start or end of the log
// would turn a mistyped ID into a replay of far more events than requested.
func filterEvents(events []TransactionEvent, fromID, toID string, lastN int) []TransactionEvent {
	if fromID != "" {
		start := eventIndex(events, fromID)
		if start < 0 {
			fmt.Fprintf(os.Stderr, "Warning: --from event %q not found in log\n", fromID)
			return nil
		}
		events = events[start:]
	}

	if toID != "" {
		end := eventIndex(events, toID)
		if end < 0 {
			fmt.Fprintf(os.Stderr, "Warning: --to event %q not found at or after the start event\n", toID)
			return nil
		}
		events = events[:end+1]
	}

	if lastN > 0 && len(events) > lastN {
		events = events[len(events)-lastN:]
	}

	// Only replay successful events
	var filtered []TransactionEvent
	for _, e := range events {
		if e.Success {
			filtered = append(filtered, e)
		}
	}

	return filtered
}

// eventIndex returns the index of the first event whose ID equals id, or -1.
func eventIndex(events []TransactionEvent, id string) int {
	for i := range events {
		if events[i].ID == id {
			return i
		}
	}
	return -1
}

// replayEvent re-executes every query of a single logged event inside one database
// transaction, in log order.
//
// If any query fails, the transaction is rolled back and an error is returned that
// wraps the driver error and quotes the start of the offending SQL. Nothing is
// committed unless every query succeeds. Begin and commit failures are wrapped and
// returned too.
func replayEvent(database *sql.DB, event TransactionEvent) error {
	ctx := context.Background()

	tx, beginErr := database.BeginTx(ctx, nil)
	if beginErr != nil {
		return fmt.Errorf("failed to begin transaction: %w", beginErr)
	}

	for idx := range event.Queries {
		query := event.Queries[idx]
		// Parameters is already []any, so it spreads straight into ExecContext.
		if _, execErr := tx.ExecContext(ctx, query.SQL, query.Parameters...); execErr != nil {
			// The exec error is what the caller needs; a rollback failure is ignored.
			_ = tx.Rollback()
			return fmt.Errorf("query failed: %w (SQL: %s)", execErr, truncateSQL(query.SQL, 50))
		}
	}

	if commitErr := tx.Commit(); commitErr != nil {
		return fmt.Errorf("failed to commit transaction: %w", commitErr)
	}
	return nil
}

// truncateSQL truncates a SQL string for display
func truncateSQL(sql string, maxLen int) string {
	sql = strings.Join(strings.Fields(sql), " ") // Normalize whitespace
	if len(sql) <= maxLen {
		return sql
	}
	return sql[:maxLen] + "..."
}