package resources
import (
"context"
"fmt"
"os"
"strings"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
var (
	// schemaPath is the filesystem location of the SQL schema file that the
	// resource readers load. It stays empty until SetSchemaPath is called.
	schemaPath string

	// tableNames is the allow-list of names accepted for the
	// "schema://table/{table_name}" resource template.
	tableNames = []string{
		"dataset",
		"location",
		"cyclic_recording_pattern",
		"cluster",
		"file",
		"moth_metadata",
		"file_metadata",
		"file_dataset",
		"selection",
		"selection_metadata",
		"ebird_taxonomy",
		"species",
		"call_type",
		"filter",
		"label",
		"label_subtype",
		"ebird_taxonomy_v2024",
		"species_dataset",
	}
)

// SetSchemaPath stores path as the schema file location for subsequent
// resource reads. The path is recorded as-is; nothing is validated here.
func SetSchemaPath(path string) {
	schemaPath = path
}
// GetSchemaResources returns the schema resources this package exposes:
//
//   - a static resource at "schema://full" for the complete SQL schema, and
//   - a resource template "schema://table/{table_name}" for a single table.
//
// The template description lists the available tables by joining tableNames,
// so it always matches the set of names accepted for table reads instead of
// duplicating that list by hand.
func GetSchemaResources() (*mcp.Resource, *mcp.ResourceTemplate) {
	fullSchemaResource := &mcp.Resource{
		URI:         "schema://full",
		Name:        "Database Schema",
		Description: "Complete SQL schema for the skraak database including all tables, indexes, and types",
		MIMEType:    "application/sql",
	}
	tableTemplate := &mcp.ResourceTemplate{
		URITemplate: "schema://table/{table_name}",
		Name:        "Table Schema",
		Description: "SQL schema for a specific table. Available tables: " + strings.Join(tableNames, ", "),
		MIMEType:    "application/sql",
	}
	return fullSchemaResource, tableTemplate
}
// SchemaResourceHandler serves MCP read requests for the schema resources.
// "schema://full" yields the whole schema file. Any URI beginning with
// "schema://table/" yields the definition of the table named by the remainder
// of the URI. Every other URI is rejected with an error. ctx is currently
// unused.
func SchemaResourceHandler(ctx context.Context, req *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) {
	const tableURIPrefix = "schema://table/"

	switch uri := req.Params.URI; {
	case uri == "schema://full":
		return readFullSchema()
	case strings.HasPrefix(uri, tableURIPrefix):
		return readTableSchema(uri[len(tableURIPrefix):])
	default:
		return nil, fmt.Errorf("unknown resource URI: %s", uri)
	}
}
// readFullSchema loads the file at schemaPath and returns its full text as a
// single "application/sql" content entry under the URI "schema://full".
// It fails if schemaPath is unset or the file cannot be read.
func readFullSchema() (*mcp.ReadResourceResult, error) {
	if len(schemaPath) == 0 {
		return nil, fmt.Errorf("schema path not set")
	}

	raw, readErr := os.ReadFile(schemaPath)
	if readErr != nil {
		return nil, fmt.Errorf("failed to read schema file: %w", readErr)
	}

	entry := &mcp.ResourceContents{
		URI:      "schema://full",
		MIMEType: "application/sql",
		Text:     string(raw),
	}
	return &mcp.ReadResourceResult{Contents: []*mcp.ResourceContents{entry}}, nil
}
// readTableSchema returns the DDL for one table as an "application/sql"
// content entry under "schema://table/<tableName>".
//
// The checks run in a fixed order: an unset schemaPath is reported first,
// then a name outside tableNames (the error lists the valid names). After
// that come failures to read the file or to locate the table's definition.
func readTableSchema(tableName string) (*mcp.ReadResourceResult, error) {
	switch {
	case schemaPath == "":
		return nil, fmt.Errorf("schema path not set")
	case !isValidTableName(tableName):
		return nil, fmt.Errorf("invalid table name: %s. Valid tables: %s", tableName, strings.Join(tableNames, ", "))
	}

	raw, readErr := os.ReadFile(schemaPath)
	if readErr != nil {
		return nil, fmt.Errorf("failed to read schema file: %w", readErr)
	}

	ddl, extractErr := extractTableDefinition(string(raw), tableName)
	if extractErr != nil {
		return nil, extractErr
	}

	entry := &mcp.ResourceContents{
		URI:      "schema://table/" + tableName,
		MIMEType: "application/sql",
		Text:     ddl,
	}
	return &mcp.ReadResourceResult{Contents: []*mcp.ResourceContents{entry}}, nil
}
// isValidTableName reports whether name exactly equals (case-sensitively) one
// of the entries in tableNames.
func isValidTableName(name string) bool {
	matched := false
	for i := 0; i < len(tableNames) && !matched; i++ {
		matched = tableNames[i] == name
	}
	return matched
}
// extractTableDefinition returns the DDL for tableName taken from schema. The
// result has three parts, in this order:
//
//   - the first CREATE TABLE or CREATE TYPE statement that defines exactly
//     that name,
//   - a blank line,
//   - every line that is a CREATE [UNIQUE] INDEX ... ON tableName or an
//     ALTER TABLE tableName statement, in schema order.
//
// Parsing is line-based: each statement header must start its own line
// (leading/trailing whitespace is ignored), and keywords are matched in upper
// case. Names are compared as whole identifiers, so "file" does not match
// "file_metadata" and "species" does not match "species_dataset". An optional
// "IF NOT EXISTS" after CREATE TABLE/TYPE or ALTER TABLE is accepted.
//
// It returns an error if tableName is empty or no defining statement exists.
func extractTableDefinition(schema string, tableName string) (string, error) {
	lines := strings.Split(schema, "\n")

	var tableLines []string
	if tableName != "" {
		tableLines = ddlCreateStatement(lines, tableName)
	}
	if len(tableLines) == 0 {
		return "", fmt.Errorf("table definition not found: %s", tableName)
	}

	// Separate the definition from its related index/alter statements.
	tableLines = append(tableLines, "")
	for _, line := range lines {
		if ddlTargetsTable(strings.TrimSpace(line), tableName) {
			tableLines = append(tableLines, line)
		}
	}
	return strings.Join(tableLines, "\n"), nil
}

// ddlCreateStatement returns the lines of the first CREATE TABLE or CREATE TYPE
// statement whose object name is exactly tableName, or nil if none exists.
//
// If the name is directly followed by AS (e.g. CREATE TABLE ... AS SELECT,
// CREATE TYPE ... AS ENUM), the statement ends at the first line ending in
// ";". Otherwise it ends at the first line, the header included, where the
// running parenthesis balance is <= 0 and the line contains ");" or ends in
// ";". Parentheses inside comments or string literals are counted too.
func ddlCreateStatement(lines []string, tableName string) []string {
	var (
		stmt    []string
		started bool
		isAS    bool
		depth   int
	)
	for _, line := range lines {
		trimmed := strings.TrimSpace(line)
		if !started {
			rest, ok := ddlNameAfter(trimmed, "CREATE TABLE ", tableName)
			if !ok {
				rest, ok = ddlNameAfter(trimmed, "CREATE TYPE ", tableName)
			}
			if !ok {
				continue
			}
			started = true
			after := strings.TrimSpace(rest)
			isAS = strings.HasPrefix(after, "AS") && (len(after) == 2 || !ddlIdentByte(after[2]))
		}

		stmt = append(stmt, line)
		depth += strings.Count(line, "(") - strings.Count(line, ")")

		if isAS {
			if strings.HasSuffix(trimmed, ";") {
				break
			}
		} else if depth <= 0 && (strings.Contains(line, ");") || strings.HasSuffix(trimmed, ";")) {
			break
		}
	}
	return stmt
}

// ddlTargetsTable reports whether the trimmed line is an ALTER TABLE statement
// for tableName, or a CREATE ... INDEX statement whose ON clause names
// tableName (with or without a space before the column list).
func ddlTargetsTable(trimmed, tableName string) bool {
	if _, ok := ddlNameAfter(trimmed, "ALTER TABLE ", tableName); ok {
		return true
	}
	if !strings.HasPrefix(trimmed, "CREATE ") || !strings.Contains(trimmed, " INDEX ") {
		return false
	}
	i := strings.Index(trimmed, " ON ")
	if i < 0 {
		return false
	}
	target := strings.TrimLeft(trimmed[i+len(" ON "):], " ")
	_, ok := ddlNameAfter(target, "", tableName)
	return ok
}

// ddlNameAfter reports whether s starts with keyword, then an optional
// "IF NOT EXISTS ", then exactly the identifier name. The name must not be
// immediately followed by another identifier character. On success it
// returns the text after the name.
func ddlNameAfter(s, keyword, name string) (string, bool) {
	if !strings.HasPrefix(s, keyword) {
		return "", false
	}
	rest := strings.TrimPrefix(s[len(keyword):], "IF NOT EXISTS ")
	if !strings.HasPrefix(rest, name) {
		return "", false
	}
	rest = rest[len(name):]
	if rest != "" && ddlIdentByte(rest[0]) {
		return "", false
	}
	return rest, true
}

// ddlIdentByte reports whether c can appear in an unquoted SQL identifier
// as used by this schema (ASCII letters, digits, underscore).
func ddlIdentByte(c byte) bool {
	return c == '_' ||
		('0' <= c && c <= '9') ||
		('a' <= c && c <= 'z') ||
		('A' <= c && c <= 'Z')
}