package tools

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"
	"unicode/utf8"

	"skraak_mcp/db"

	"github.com/modelcontextprotocol/go-sdk/mcp"
)

// CreateClusterInput defines the input parameters for the create_cluster tool.
//
// Fields:
//   - DatasetID, LocationID: IDs of the parent dataset and location. The
//     create_cluster handler additionally requires the location to belong to
//     the given dataset.
//   - Name: the cluster name; the handler rejects blank names and enforces a
//     140-character limit.
//   - SampleRate: sample rate in Hz; the handler rejects values <= 0.
//   - CyclicRecordingPatternID: optional (nil / omitted when absent).
//   - Description: optional (nil / omitted when absent); the handler enforces a
//     255-character limit.
//
// The jsonschema tags presumably feed the MCP SDK's generated input schema
// for the tool.
// NOTE(review): if the SDK treats the whole jsonschema tag as the field
// description, the leading "required," will appear verbatim in that text;
// required-ness is typically inferred from the absence of omitempty instead —
// confirm against the SDK version in use.
type CreateClusterInput struct {
	DatasetID                string  `json:"dataset_id" jsonschema:"required,ID of the parent dataset (12-character nanoid)"`
	LocationID               string  `json:"location_id" jsonschema:"required,ID of the parent location (12-character nanoid)"`
	Name                     string  `json:"name" jsonschema:"required,Cluster name (max 140 characters)"`
	SampleRate               int     `json:"sample_rate" jsonschema:"required,Sample rate in Hz (must be positive)"`
	CyclicRecordingPatternID *string `json:"cyclic_recording_pattern_id,omitempty" jsonschema:"Optional ID of cyclic recording pattern (12-character nanoid)"`
	Description              *string `json:"description,omitempty" jsonschema:"Optional cluster description (max 255 characters)"`
}

// CreateClusterOutput defines the output structure of the create_cluster tool.
//
// Cluster holds the newly inserted row as read back from the database after
// the insert (so it includes the ID and timestamps set during creation).
// Message is a human-readable summary of what was created.
type CreateClusterOutput struct {
	Cluster db.Cluster `json:"cluster" jsonschema:"The created cluster with generated ID and timestamps"`
	Message string     `json:"message" jsonschema:"Success message"`
}

// CreateCluster implements the create_cluster tool handler.
//
// It creates a new cluster within a location after checking that:
//   - name is non-blank and at most 140 characters, description (if given) is
//     at most 255 characters, and sample_rate is positive;
//   - dataset_id and location_id are non-blank, exist, and are active;
//   - the location belongs to the specified dataset (business rule);
//   - the cyclic recording pattern, when a non-blank ID is given, exists and
//     is active. A blank pattern ID is treated as "no pattern" (stored NULL).
//
// All database checks, the insert and the read-back of the new row run in a
// single transaction on the database at dbPath. Any failure returns a non-nil
// error and the transaction is rolled back. On success the inserted row and a
// human-readable message are returned.
func CreateCluster(
	ctx context.Context,
	req *mcp.CallToolRequest,
	input CreateClusterInput,
) (*mcp.CallToolResult, CreateClusterOutput, error) {
	var output CreateClusterOutput

	// Validate name. Limits are stated in characters, so count runes rather
	// than UTF-8 bytes; len() would wrongly reject shorter non-ASCII names.
	if strings.TrimSpace(input.Name) == "" {
		return nil, output, fmt.Errorf("name cannot be empty")
	}
	if n := utf8.RuneCountInString(input.Name); n > 140 {
		return nil, output, fmt.Errorf("name must be 140 characters or less (got %d)", n)
	}

	// Validate description length if provided.
	if input.Description != nil {
		if n := utf8.RuneCountInString(*input.Description); n > 255 {
			return nil, output, fmt.Errorf("description must be 255 characters or less (got %d)", n)
		}
	}

	// Validate sample rate.
	if input.SampleRate <= 0 {
		return nil, output, fmt.Errorf("sample_rate must be positive (got %d)", input.SampleRate)
	}

	// Validate IDs not empty.
	if strings.TrimSpace(input.DatasetID) == "" {
		return nil, output, fmt.Errorf("dataset_id cannot be empty")
	}
	if strings.TrimSpace(input.LocationID) == "" {
		return nil, output, fmt.Errorf("location_id cannot be empty")
	}

	// A blank pattern ID means "no pattern". Normalise it to nil so the insert
	// stores NULL instead of an empty/whitespace string that was never checked
	// against cyclic_recording_pattern.
	patternID := input.CyclicRecordingPatternID
	if patternID != nil && strings.TrimSpace(*patternID) == "" {
		patternID = nil
	}

	// Open writable database connection.
	database, err := db.OpenWriteableDB(dbPath)
	if err != nil {
		return nil, output, fmt.Errorf("database connection failed: %w", err)
	}
	defer database.Close()

	// Begin transaction.
	tx, err := database.BeginTx(ctx, nil)
	if err != nil {
		return nil, output, fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Roll back on every exit path. The validation failures below return with
	// err == nil, so a rollback conditioned on err would leak the transaction.
	// After a successful Commit, Rollback only reports sql.ErrTxDone, which is
	// deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	// Verify dataset exists and is active. A missing row surfaces from Scan as
	// sql.ErrNoRows.
	var datasetActive bool
	var datasetName string
	err = tx.QueryRowContext(ctx,
		"SELECT active, name FROM dataset WHERE id = ?",
		input.DatasetID,
	).Scan(&datasetActive, &datasetName)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, output, fmt.Errorf("dataset with ID '%s' does not exist", input.DatasetID)
	case err != nil:
		return nil, output, fmt.Errorf("failed to verify dataset: %w", err)
	case !datasetActive:
		return nil, output, fmt.Errorf("dataset '%s' (ID: %s) is not active", datasetName, input.DatasetID)
	}

	// Verify location exists, is active, and belongs to the specified dataset.
	var locationActive bool
	var locationName string
	var locationDatasetID string
	err = tx.QueryRowContext(ctx,
		"SELECT active, name, dataset_id FROM location WHERE id = ?",
		input.LocationID,
	).Scan(&locationActive, &locationName, &locationDatasetID)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, output, fmt.Errorf("location with ID '%s' does not exist", input.LocationID)
	case err != nil:
		return nil, output, fmt.Errorf("failed to verify location: %w", err)
	case !locationActive:
		return nil, output, fmt.Errorf("location '%s' (ID: %s) is not active", locationName, input.LocationID)
	case locationDatasetID != input.DatasetID:
		// CRITICAL BUSINESS RULE: location must belong to the specified dataset.
		return nil, output, fmt.Errorf("location '%s' (ID: %s) does not belong to dataset '%s' (ID: %s) - it belongs to dataset ID '%s'",
			locationName, input.LocationID, datasetName, input.DatasetID, locationDatasetID)
	}

	// Verify cyclic recording pattern if provided.
	if patternID != nil {
		var patternActive bool
		err = tx.QueryRowContext(ctx,
			"SELECT active FROM cyclic_recording_pattern WHERE id = ?",
			*patternID,
		).Scan(&patternActive)
		switch {
		case errors.Is(err, sql.ErrNoRows):
			return nil, output, fmt.Errorf("cyclic recording pattern with ID '%s' does not exist", *patternID)
		case err != nil:
			return nil, output, fmt.Errorf("failed to verify cyclic recording pattern: %w", err)
		case !patternActive:
			return nil, output, fmt.Errorf("cyclic recording pattern with ID '%s' is not active", *patternID)
		}
	}

	// Generate ID.
	id, err := db.GenerateID()
	if err != nil {
		return nil, output, fmt.Errorf("failed to generate ID: %w", err)
	}

	// Insert cluster (explicitly set timestamps and active for schema
	// compatibility). A nil patternID is written as NULL.
	_, err = tx.ExecContext(ctx,
		"INSERT INTO cluster (id, dataset_id, location_id, name, sample_rate, cyclic_recording_pattern_id, description, created_at, last_modified, active) VALUES (?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, TRUE)",
		id, input.DatasetID, input.LocationID, input.Name, input.SampleRate, patternID, input.Description,
	)
	if err != nil {
		return nil, output, fmt.Errorf("failed to create cluster: %w", err)
	}

	// Read the row back so the output carries the stored timestamps and values.
	var cluster db.Cluster
	err = tx.QueryRowContext(ctx,
		"SELECT id, dataset_id, location_id, name, description, created_at, last_modified, active, cyclic_recording_pattern_id, sample_rate FROM cluster WHERE id = ?",
		id,
	).Scan(&cluster.ID, &cluster.DatasetID, &cluster.LocationID, &cluster.Name, &cluster.Description,
		&cluster.CreatedAt, &cluster.LastModified, &cluster.Active, &cluster.CyclicRecordingPatternID, &cluster.SampleRate)
	if err != nil {
		return nil, output, fmt.Errorf("failed to fetch created cluster: %w", err)
	}

	// Commit transaction.
	if err = tx.Commit(); err != nil {
		return nil, output, fmt.Errorf("failed to commit transaction: %w", err)
	}

	output.Cluster = cluster
	output.Message = fmt.Sprintf("Successfully created cluster '%s' with ID %s in location '%s' at dataset '%s' (sample rate: %d Hz)",
		cluster.Name, cluster.ID, locationName, datasetName, cluster.SampleRate)

	return &mcp.CallToolResult{}, output, nil
}