package tools

import (
	"context"
	"database/sql"
	"encoding/base64"
	"fmt"
	"regexp"
	"strings"
	"time"

	"skraak_mcp/db"

	"github.com/modelcontextprotocol/go-sdk/mcp"
)

// dbPath is the location of the database the tools in this package open.
// It stays empty until SetDBPath is called.
var dbPath string

// SetDBPath records path as the database location used by this package's
// tools. It is called from main.go during initialization; the assignment is
// unsynchronized, so it should not run concurrently with tool handlers.
func SetDBPath(path string) { dbPath = path }

// ExecuteSQLInput defines the input parameters for the execute_sql tool.
//
// Query is mandatory; Parameters and Limit may be omitted.
//
// NOTE(review): with the MCP Go SDK's schema inference, the jsonschema tag
// text becomes the field description verbatim. Requiredness comes from the
// absence of omitempty in the json tag, so no "required" marker belongs in
// the jsonschema text. Re-check if the SDK version changes.
type ExecuteSQLInput struct {
	Query      string `json:"query" jsonschema:"SQL SELECT query to execute"`
	Parameters []any  `json:"parameters,omitempty" jsonschema:"Optional parameters for parameterized queries (use ? placeholders)"`
	Limit      *int   `json:"limit,omitempty" jsonschema:"Maximum rows to return (default 1000 max 10000)"`
}

// ColumnInfo describes a single column of an execute_sql result set.
type ColumnInfo struct {
	// Name is the column label returned by the query.
	Name         string `json:"name" jsonschema:"Column name"`
	// DatabaseType is the type name the database driver reports for the column.
	DatabaseType string `json:"database_type" jsonschema:"Database type of the column"`
}

// ExecuteSQLOutput is the structured result returned by the execute_sql tool.
type ExecuteSQLOutput struct {
	// Rows holds one map per result row, keyed by column name.
	Rows     []map[string]any `json:"rows" jsonschema:"Query result rows"`
	// RowCount is the number of entries in Rows.
	RowCount int              `json:"row_count" jsonschema:"Number of rows returned"`
	// Columns lists the result columns in query order.
	Columns  []ColumnInfo     `json:"columns" jsonschema:"Column metadata"`
	// Limited reports whether rows were dropped because the row cap was hit.
	Limited  bool             `json:"limited" jsonschema:"Whether results were truncated due to row limit"`
	// Query is the SQL text sent to the database, including any LIMIT
	// clause the handler appended.
	Query    string           `json:"query_executed" jsonschema:"The actual query executed (with LIMIT applied)"`
}

// Validation patterns applied to execute_sql queries.
var (
	// selectPattern requires the query to begin with SELECT or WITH followed
	// by whitespace. The match is case-insensitive and ignores leading whitespace.
	selectPattern = regexp.MustCompile(`(?i)^\s*(SELECT|WITH)\s+`)

	// forbiddenPattern flags write, DDL and permission keywords appearing
	// anywhere in the query as whole words. It is a coarse textual guard, not
	// a parser: it also fires on those words inside string literals or comments.
	forbiddenPattern = regexp.MustCompile(`(?i)\b(INSERT|UPDATE|DELETE|DROP|CREATE|ALTER|TRUNCATE|GRANT|REVOKE)\b`)

	// limitPattern detects an existing LIMIT clause so a second one is not
	// appended. The count may be a literal integer or a bind placeholder
	// (`?` or `$N`). Without the placeholder forms, "LIMIT ?" would get
	// another LIMIT tacked on and fail to parse. The match is textual, so a
	// LIMIT inside a subquery also counts.
	limitPattern = regexp.MustCompile(`(?i)\bLIMIT\s+(?:\d+|\?|\$\d+)`)
)

// Row caps for execute_sql results.
const (
	// defaultLimit is the cap used when ExecuteSQLInput.Limit is omitted.
	defaultLimit = 1_000
	// maxLimit is the largest cap a caller may request.
	maxLimit = 10_000
)

// ExecuteSQL implements the execute_sql tool handler.
//
// It runs one caller-supplied SELECT (or WITH ... SELECT) query against the
// database at dbPath. It returns the rows (values passed through
// convertValue), column metadata, and the exact query text executed.
// Safety checks, in order:
//   - the query must be non-blank and start with SELECT or WITH;
//   - it must not contain any keyword matched by forbiddenPattern;
//   - the connection is obtained from db.OpenReadOnlyDB.
//
// At most limit rows are returned. limit is input.Limit, which must be in
// 1..maxLimit, or defaultLimit when omitted. If the query has no LIMIT clause
// of its own, "LIMIT limit+1" is appended. The extra row is never returned;
// receiving it is how Limited is set when the result was truncated.
//
// input.Parameters are bound to the query's placeholders. Validation,
// connection, query, scan and iteration failures are all returned as errors.
func ExecuteSQL(
	ctx context.Context,
	req *mcp.CallToolRequest,
	input ExecuteSQLInput,
) (*mcp.CallToolResult, ExecuteSQLOutput, error) {
	if strings.TrimSpace(input.Query) == "" {
		return nil, ExecuteSQLOutput{}, fmt.Errorf("query cannot be empty")
	}
	if !selectPattern.MatchString(input.Query) {
		return nil, ExecuteSQLOutput{}, fmt.Errorf("only SELECT and WITH queries are allowed")
	}
	// Defense in depth: the connection opened below is already read-only.
	if forbiddenPattern.MatchString(input.Query) {
		return nil, ExecuteSQLOutput{}, fmt.Errorf("query contains forbidden keywords (INSERT/UPDATE/DELETE/DROP/CREATE/ALTER/TRUNCATE/GRANT/REVOKE)")
	}

	limit := defaultLimit
	if input.Limit != nil {
		if *input.Limit < 1 || *input.Limit > maxLimit {
			return nil, ExecuteSQLOutput{}, fmt.Errorf("limit must be between 1 and %d", maxLimit)
		}
		limit = *input.Limit
	}

	query := input.Query
	if !limitPattern.MatchString(query) {
		// Trailing semicolons are dropped so the clause joins the same
		// statement. The clause starts on its own line so a trailing
		// "-- comment" cannot swallow it. One extra row is requested only to
		// detect truncation in the scan loop; it is never returned.
		body := strings.TrimRight(strings.TrimSpace(query), "; \t\r\n")
		query = fmt.Sprintf("%s\nLIMIT %d", body, limit+1)
	}

	database, err := db.OpenReadOnlyDB(dbPath)
	if err != nil {
		return nil, ExecuteSQLOutput{}, fmt.Errorf("database connection failed: %w", err)
	}
	defer database.Close()

	// An empty (or nil) Parameters slice expands to no bind arguments.
	var rows *sql.Rows
	rows, err = database.QueryContext(ctx, query, input.Parameters...)
	if err != nil {
		return nil, ExecuteSQLOutput{}, fmt.Errorf("query execution failed: %w", err)
	}
	defer rows.Close()

	columns, err := rows.Columns()
	if err != nil {
		return nil, ExecuteSQLOutput{}, fmt.Errorf("failed to get columns: %w", err)
	}
	columnTypes, err := rows.ColumnTypes()
	if err != nil {
		return nil, ExecuteSQLOutput{}, fmt.Errorf("failed to get column types: %w", err)
	}
	columnInfo := make([]ColumnInfo, len(columns))
	for i, name := range columns {
		columnInfo[i] = ColumnInfo{Name: name, DatabaseType: columnTypes[i].DatabaseTypeName()}
	}

	// Scan destinations are allocated once and reused for every row. Scan
	// overwrites each slot, and copies []byte data it stores into *any. Each
	// row's map keeps its own converted values.
	values := make([]any, len(columns))
	scanDest := make([]any, len(columns))
	for i := range values {
		scanDest[i] = &values[i]
	}

	// Non-nil so an empty result encodes as [] rather than null.
	results := make([]map[string]any, 0)
	limited := false
	for rows.Next() {
		// Any row beyond the cap means the result set was truncated.
		if len(results) >= limit {
			limited = true
			break
		}
		if err := rows.Scan(scanDest...); err != nil {
			return nil, ExecuteSQLOutput{}, fmt.Errorf("row scan failed: %w", err)
		}
		// NOTE(review): duplicate column names (e.g. t1.id and t2.id from a
		// join) share one map key, so only the last such column survives.
		row := make(map[string]any, len(columns))
		for i, name := range columns {
			row[name] = convertValue(values[i])
		}
		results = append(results, row)
	}
	if err := rows.Err(); err != nil {
		return nil, ExecuteSQLOutput{}, fmt.Errorf("row iteration failed: %w", err)
	}

	return &mcp.CallToolResult{}, ExecuteSQLOutput{
		Rows:     results,
		RowCount: len(results),
		Columns:  columnInfo,
		Limited:  limited,
		Query:    query,
	}, nil
}

// convertValue converts a value scanned from the database into a form that
// encodes cleanly as JSON:
//   - nil (SQL NULL) stays nil;
//   - time.Time becomes an RFC3339 string (consistent with existing code);
//   - []byte becomes a standard base64 string;
//   - bool, string and every built-in integer and float width pass through
//     unchanged, so numbers stay JSON numbers whatever width the driver uses;
//   - []any and map[string]any (list/struct values some drivers return) are
//     converted element by element;
//   - anything else is formatted with %v.
func convertValue(val any) any {
	switch v := val.(type) {
	case nil:
		return nil
	case time.Time:
		return v.Format(time.RFC3339)
	case []byte:
		return base64.StdEncoding.EncodeToString(v)
	case bool, string,
		int, int8, int16, int32, int64,
		uint, uint8, uint16, uint32, uint64,
		float32, float64:
		return v
	case []any:
		out := make([]any, len(v))
		for i, elem := range v {
			out[i] = convertValue(elem)
		}
		return out
	case map[string]any:
		out := make(map[string]any, len(v))
		for k, elem := range v {
			out[k] = convertValue(elem)
		}
		return out
	default:
		return fmt.Sprintf("%v", v)
	}
}