package tools
import (
	"context"
	"fmt"
	"strings"
	"unicode/utf8"

	"github.com/modelcontextprotocol/go-sdk/mcp"
	"skraak_mcp/db"
)
// CreateDatasetInput carries the arguments accepted by the CreateDataset tool.
// Optional fields are pointers so that "not supplied" (nil) can be told apart
// from an explicitly supplied empty string.
type CreateDatasetInput struct {
	// Name is the required dataset name.
	Name string `json:"name" jsonschema:"required,Dataset name (max 255 characters)"`

	// Description is an optional free-text description; nil when omitted.
	Description *string `json:"description,omitempty" jsonschema:"Optional dataset description (max 255 characters)"`

	// Type optionally selects the dataset type ('organise', 'test' or
	// 'train'); nil when omitted.
	Type *string `json:"type,omitempty" jsonschema:"Dataset type: 'organise'/'test'/'train' (defaults to 'organise')"`
}
// CreateDatasetOutput is the structured result returned by the CreateDataset
// tool on success.
type CreateDatasetOutput struct {
	// Dataset is the row as read back from the database after the insert.
	Dataset db.Dataset `json:"dataset" jsonschema:"The created dataset with generated ID and timestamps"`

	// Message is a human-readable success summary.
	Message string `json:"message" jsonschema:"Success message"`
}
func CreateDataset(
ctx context.Context,
req *mcp.CallToolRequest,
input CreateDatasetInput,
) (*mcp.CallToolResult, CreateDatasetOutput, error) {
var output CreateDatasetOutput
if strings.TrimSpace(input.Name) == "" {
return nil, output, fmt.Errorf("name cannot be empty")
}
if len(input.Name) > 255 {
return nil, output, fmt.Errorf("name must be 255 characters or less (got %d)", len(input.Name))
}
if input.Description != nil && len(*input.Description) > 255 {
return nil, output, fmt.Errorf("description must be 255 characters or less (got %d)", len(*input.Description))
}
datasetType := db.DatasetTypeOrganise if input.Type != nil {
typeStr := strings.ToLower(strings.TrimSpace(*input.Type))
switch typeStr {
case "organise":
datasetType = db.DatasetTypeOrganise
case "test":
datasetType = db.DatasetTypeTest
case "train":
datasetType = db.DatasetTypeTrain
default:
return nil, output, fmt.Errorf("invalid type '%s': must be 'organise', 'test', or 'train'", *input.Type)
}
}
database, err := db.OpenWriteableDB(dbPath)
if err != nil {
return nil, output, fmt.Errorf("database connection failed: %w", err)
}
defer database.Close()
tx, err := database.BeginTx(ctx, nil)
if err != nil {
return nil, output, fmt.Errorf("failed to begin transaction: %w", err)
}
defer func() {
if err != nil {
tx.Rollback()
}
}()
id, err := db.GenerateID()
if err != nil {
return nil, output, fmt.Errorf("failed to generate ID: %w", err)
}
_, err = tx.ExecContext(ctx,
"INSERT INTO dataset (id, name, description, type, created_at, last_modified, active) VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, TRUE)",
id, input.Name, input.Description, string(datasetType),
)
if err != nil {
return nil, output, fmt.Errorf("failed to create dataset: %w", err)
}
var dataset db.Dataset
err = tx.QueryRowContext(ctx,
"SELECT id, name, description, created_at, last_modified, active, type FROM dataset WHERE id = ?",
id,
).Scan(&dataset.ID, &dataset.Name, &dataset.Description, &dataset.CreatedAt, &dataset.LastModified, &dataset.Active, &dataset.Type)
if err != nil {
return nil, output, fmt.Errorf("failed to fetch created dataset: %w", err)
}
if err = tx.Commit(); err != nil {
return nil, output, fmt.Errorf("failed to commit transaction: %w", err)
}
output.Dataset = dataset
output.Message = fmt.Sprintf("Successfully created dataset '%s' with ID %s (type: %s)",
dataset.Name, dataset.ID, dataset.Type)
return &mcp.CallToolResult{}, output, nil
}