package tools
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"
	"unicode/utf8"

	"skraak_mcp/db"

	"github.com/modelcontextprotocol/go-sdk/mcp"
)
// CreateClusterInput is the argument payload accepted by the CreateCluster
// tool. The optional fields are pointers tagged omitempty; nil means the
// caller did not supply them.
//
// NOTE(review): the jsonschema tags on the mandatory fields start with
// "required,". Confirm whether the schema generator in use treats that as a
// keyword or copies it verbatim into the field description.
type CreateClusterInput struct {
	// DatasetID identifies the dataset the new cluster is created under.
	DatasetID string `json:"dataset_id" jsonschema:"required,ID of the parent dataset (12-character nanoid)"`

	// LocationID identifies the location of the new cluster.
	LocationID string `json:"location_id" jsonschema:"required,ID of the parent location (12-character nanoid)"`

	// Name is the cluster's display name.
	Name string `json:"name" jsonschema:"required,Cluster name (max 140 characters)"`

	// SampleRate is the recording sample rate in Hz.
	SampleRate int `json:"sample_rate" jsonschema:"required,Sample rate in Hz (must be positive)"`

	// CyclicRecordingPatternID optionally links the cluster to a cyclic
	// recording pattern.
	CyclicRecordingPatternID *string `json:"cyclic_recording_pattern_id,omitempty" jsonschema:"Optional ID of cyclic recording pattern (12-character nanoid)"`

	// Description is an optional free-text description.
	Description *string `json:"description,omitempty" jsonschema:"Optional cluster description (max 255 characters)"`
}
// CreateClusterOutput is the structured result the CreateCluster tool
// returns on success.
type CreateClusterOutput struct {
	// Cluster is the newly inserted cluster row.
	Cluster db.Cluster `json:"cluster" jsonschema:"The created cluster with generated ID and timestamps"`

	// Message is a human-readable summary of what was created.
	Message string `json:"message" jsonschema:"Success message"`
}
// CreateCluster is the MCP tool handler that creates a cluster under an
// existing dataset and location.
//
// It first validates the input:
//   - name, dataset_id and location_id must be non-blank;
//   - name must be at most 140 characters;
//   - description must be at most 255 characters;
//   - sample_rate must be positive.
//
// It then opens the writeable database at dbPath and, inside a single
// transaction, verifies that:
//   - the dataset exists and is active;
//   - the location exists, is active and belongs to that dataset;
//   - the cyclic recording pattern, when a non-blank ID is supplied, exists
//     and is active.
//
// The cluster is inserted with a freshly generated ID, active = TRUE and
// created_at/last_modified = CURRENT_TIMESTAMP. It is then read back and the
// transaction is committed.
//
// On success it returns an empty *mcp.CallToolResult together with the stored
// cluster and a summary message. On any failure it returns a nil result and a
// non-nil error, and the transaction (if one was begun) is rolled back.
func CreateCluster(
	ctx context.Context,
	req *mcp.CallToolRequest,
	input CreateClusterInput,
) (*mcp.CallToolResult, CreateClusterOutput, error) {
	var output CreateClusterOutput

	if strings.TrimSpace(input.Name) == "" {
		return nil, output, fmt.Errorf("name cannot be empty")
	}
	// The limits are documented in characters, so count runes. len() counts
	// bytes and would reject non-ASCII text that is actually within the limit.
	if n := utf8.RuneCountInString(input.Name); n > 140 {
		return nil, output, fmt.Errorf("name must be 140 characters or less (got %d)", n)
	}
	if input.Description != nil {
		if n := utf8.RuneCountInString(*input.Description); n > 255 {
			return nil, output, fmt.Errorf("description must be 255 characters or less (got %d)", n)
		}
	}
	if input.SampleRate <= 0 {
		return nil, output, fmt.Errorf("sample_rate must be positive (got %d)", input.SampleRate)
	}
	if strings.TrimSpace(input.DatasetID) == "" {
		return nil, output, fmt.Errorf("dataset_id cannot be empty")
	}
	if strings.TrimSpace(input.LocationID) == "" {
		return nil, output, fmt.Errorf("location_id cannot be empty")
	}

	// A blank pattern ID means "no pattern". Normalise it to nil so it is
	// stored as NULL instead of being inserted, unvalidated, as a foreign key.
	var patternID *string
	if p := input.CyclicRecordingPatternID; p != nil && strings.TrimSpace(*p) != "" {
		patternID = p
	}

	database, err := db.OpenWriteableDB(dbPath)
	if err != nil {
		return nil, output, fmt.Errorf("database connection failed: %w", err)
	}
	defer database.Close()

	tx, err := database.BeginTx(ctx, nil)
	if err != nil {
		return nil, output, fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Roll back on every early return, including the validation failures
	// below that return while err == nil. After a successful Commit this is a
	// no-op (Rollback returns sql.ErrTxDone), so its error is ignored.
	defer func() { _ = tx.Rollback() }()

	// Row.Scan reports a missing row as sql.ErrNoRows; that is how
	// non-existence of each referenced entity is detected below.
	var datasetActive bool
	var datasetName string
	err = tx.QueryRowContext(ctx,
		"SELECT active, name FROM dataset WHERE id = ?",
		input.DatasetID,
	).Scan(&datasetActive, &datasetName)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, output, fmt.Errorf("dataset with ID '%s' does not exist", input.DatasetID)
	case err != nil:
		return nil, output, fmt.Errorf("failed to verify dataset: %w", err)
	case !datasetActive:
		return nil, output, fmt.Errorf("dataset '%s' (ID: %s) is not active", datasetName, input.DatasetID)
	}

	var locationActive bool
	var locationName, locationDatasetID string
	err = tx.QueryRowContext(ctx,
		"SELECT active, name, dataset_id FROM location WHERE id = ?",
		input.LocationID,
	).Scan(&locationActive, &locationName, &locationDatasetID)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, output, fmt.Errorf("location with ID '%s' does not exist", input.LocationID)
	case err != nil:
		return nil, output, fmt.Errorf("failed to verify location: %w", err)
	case !locationActive:
		return nil, output, fmt.Errorf("location '%s' (ID: %s) is not active", locationName, input.LocationID)
	case locationDatasetID != input.DatasetID:
		return nil, output, fmt.Errorf("location '%s' (ID: %s) does not belong to dataset '%s' (ID: %s) - it belongs to dataset ID '%s'",
			locationName, input.LocationID, datasetName, input.DatasetID, locationDatasetID)
	}

	if patternID != nil {
		var patternActive bool
		err = tx.QueryRowContext(ctx,
			"SELECT active FROM cyclic_recording_pattern WHERE id = ?",
			*patternID,
		).Scan(&patternActive)
		switch {
		case errors.Is(err, sql.ErrNoRows):
			return nil, output, fmt.Errorf("cyclic recording pattern with ID '%s' does not exist", *patternID)
		case err != nil:
			return nil, output, fmt.Errorf("failed to verify cyclic recording pattern: %w", err)
		case !patternActive:
			return nil, output, fmt.Errorf("cyclic recording pattern with ID '%s' is not active", *patternID)
		}
	}

	id, err := db.GenerateID()
	if err != nil {
		return nil, output, fmt.Errorf("failed to generate ID: %w", err)
	}

	// patternID and input.Description are *string; a nil pointer is sent as NULL.
	_, err = tx.ExecContext(ctx,
		"INSERT INTO cluster (id, dataset_id, location_id, name, sample_rate, cyclic_recording_pattern_id, description, created_at, last_modified, active) VALUES (?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, TRUE)",
		id, input.DatasetID, input.LocationID, input.Name, input.SampleRate, patternID, input.Description,
	)
	if err != nil {
		return nil, output, fmt.Errorf("failed to create cluster: %w", err)
	}

	// Read the row back inside the same transaction so the returned value
	// carries the database-assigned timestamps.
	var cluster db.Cluster
	err = tx.QueryRowContext(ctx,
		"SELECT id, dataset_id, location_id, name, description, created_at, last_modified, active, cyclic_recording_pattern_id, sample_rate FROM cluster WHERE id = ?",
		id,
	).Scan(&cluster.ID, &cluster.DatasetID, &cluster.LocationID, &cluster.Name, &cluster.Description,
		&cluster.CreatedAt, &cluster.LastModified, &cluster.Active, &cluster.CyclicRecordingPatternID, &cluster.SampleRate)
	if err != nil {
		return nil, output, fmt.Errorf("failed to fetch created cluster: %w", err)
	}

	if err = tx.Commit(); err != nil {
		return nil, output, fmt.Errorf("failed to commit transaction: %w", err)
	}

	output.Cluster = cluster
	output.Message = fmt.Sprintf("Successfully created cluster '%s' with ID %s in location '%s' at dataset '%s' (sample rate: %d Hz)",
		cluster.Name, cluster.ID, locationName, datasetName, cluster.SampleRate)
	return &mcp.CallToolResult{}, output, nil
}