package tools
import (
"context"
"database/sql"
"encoding/base64"
"fmt"
"regexp"
"strings"
"time"
"skraak_mcp/db"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
// dbPath is the database location that ExecuteSQL opens for every query.
// It is plain package state with no synchronization.
// NOTE(review): presumably assigned once at startup before any tool call —
// confirm, otherwise concurrent SetDBPath/ExecuteSQL calls race.
var dbPath string

// SetDBPath stores path as the database that ExecuteSQL will open.
func SetDBPath(path string) { dbPath = path }
// ExecuteSQLInput is the argument object decoded for the ExecuteSQL tool.
//
//   - Query is the SQL text to run.
//   - Parameters, when present, are passed positionally to the query's
//     placeholders.
//   - Limit optionally overrides the row cap; ExecuteSQL validates its range.
type ExecuteSQLInput struct {
	Query      string        `json:"query" jsonschema:"required,SQL SELECT query to execute"`
	Parameters []interface{} `json:"parameters,omitempty" jsonschema:"Optional parameters for parameterized queries (use ? placeholders)"`
	Limit      *int          `json:"limit,omitempty" jsonschema:"Maximum rows to return (default 1000 max 10000)"`
}
// ColumnInfo describes one result column: its name as reported by the
// driver and the driver-reported database type name.
type ColumnInfo struct {
	Name         string `json:"name" jsonschema:"Column name"`
	DatabaseType string `json:"database_type" jsonschema:"Database type of the column"`
}
// ExecuteSQLOutput is the structured result returned by ExecuteSQL.
//
//   - Rows holds one column-name → value map per returned row.
//   - RowCount is len(Rows).
//   - Columns lists each column's name and database type, in result order.
//   - Limited reports whether rows were dropped because of the row cap.
//   - Query is the SQL text that was actually sent to the database.
type ExecuteSQLOutput struct {
	Rows     []map[string]interface{} `json:"rows" jsonschema:"Query result rows"`
	RowCount int                      `json:"row_count" jsonschema:"Number of rows returned"`
	Columns  []ColumnInfo             `json:"columns" jsonschema:"Column metadata"`
	Limited  bool                     `json:"limited" jsonschema:"Whether results were truncated due to row limit"`
	Query    string                   `json:"query_executed" jsonschema:"The actual query executed (with LIMIT applied)"`
}
// Query screening patterns used by ExecuteSQL. They are coarse keyword
// checks, not a SQL parser. NOTE(review): the real write protection
// presumably comes from the read-only connection opened by
// db.OpenReadOnlyDB — confirm.
var (
	// selectPattern accepts queries whose first keyword is SELECT or WITH
	// (case-insensitive, leading whitespace allowed).
	selectPattern = regexp.MustCompile(`(?i)^\s*(SELECT|WITH)\s+`)

	// forbiddenPattern rejects any whole-word occurrence of a data- or
	// schema-modifying keyword anywhere in the query text.
	forbiddenPattern = regexp.MustCompile(`(?i)\b(INSERT|UPDATE|DELETE|DROP|CREATE|ALTER|TRUNCATE|GRANT|REVOKE)\b`)

	// limitPattern detects a LIMIT clause the caller already wrote. Besides
	// a literal count it accepts bound-parameter placeholders ("?" or "$N").
	// Without that, "LIMIT ?" would be treated as missing and a second LIMIT
	// appended, producing invalid SQL.
	limitPattern = regexp.MustCompile(`(?i)\bLIMIT\s+(?:\d+|\?|\$\d+)`)
)
// Row caps applied by ExecuteSQL.
const (
	defaultLimit = 1000  // used when ExecuteSQLInput.Limit is nil
	maxLimit     = 10000 // largest ExecuteSQLInput.Limit that is accepted
)
// ExecuteSQL is the MCP tool handler that runs one SQL query against the
// database at dbPath, opened via db.OpenReadOnlyDB, and returns the rows.
//
// Validation, in order:
//   - the query must not be blank;
//   - it must start with SELECT or WITH (selectPattern);
//   - it must contain none of the keywords matched by forbiddenPattern;
//   - input.Limit, when set, must lie in [1, maxLimit]. Otherwise defaultLimit
//     is used.
//
// Trailing semicolons and whitespace are stripped from the query. If it has
// no LIMIT clause of its own (limitPattern), "LIMIT <limit+1>" is appended.
// The extra row lets the handler tell whether results were cut off. At most
// limit rows are returned either way, and output.Limited is true when more
// were available. output.Query is the exact SQL text sent to the database.
//
// req is unused. On any validation, connection, query, scan or iteration
// failure, a nil *mcp.CallToolResult, a zero ExecuteSQLOutput and a
// descriptive error are returned.
func ExecuteSQL(
	ctx context.Context,
	req *mcp.CallToolRequest,
	input ExecuteSQLInput,
) (*mcp.CallToolResult, ExecuteSQLOutput, error) {
	if strings.TrimSpace(input.Query) == "" {
		return nil, ExecuteSQLOutput{}, fmt.Errorf("query cannot be empty")
	}
	if !selectPattern.MatchString(input.Query) {
		return nil, ExecuteSQLOutput{}, fmt.Errorf("only SELECT and WITH queries are allowed")
	}
	if forbiddenPattern.MatchString(input.Query) {
		return nil, ExecuteSQLOutput{}, fmt.Errorf("query contains forbidden keywords (INSERT/UPDATE/DELETE/DROP/CREATE/ALTER/TRUNCATE/GRANT/REVOKE)")
	}

	limit := defaultLimit
	if input.Limit != nil {
		if *input.Limit < 1 || *input.Limit > maxLimit {
			return nil, ExecuteSQLOutput{}, fmt.Errorf("limit must be between 1 and %d", maxLimit)
		}
		limit = *input.Limit
	}

	// Strip trailing ";" so an appended LIMIT stays inside the statement
	// ("SELECT 1;" would otherwise become "SELECT 1; LIMIT n").
	query := strings.TrimRight(strings.TrimSpace(input.Query), "; \t\r\n")
	if !limitPattern.MatchString(query) {
		// Ask for one row beyond the cap so truncation is detectable. The
		// newline keeps the clause out of any trailing "--" line comment.
		query = fmt.Sprintf("%s\nLIMIT %d", query, limit+1)
	}

	database, err := db.OpenReadOnlyDB(dbPath)
	if err != nil {
		return nil, ExecuteSQLOutput{}, fmt.Errorf("database connection failed: %w", err)
	}
	defer database.Close()

	// A nil Parameters slice expands to no arguments.
	rows, err := database.QueryContext(ctx, query, input.Parameters...)
	if err != nil {
		return nil, ExecuteSQLOutput{}, fmt.Errorf("query execution failed: %w", err)
	}
	defer rows.Close()

	columns, err := rows.Columns()
	if err != nil {
		return nil, ExecuteSQLOutput{}, fmt.Errorf("failed to get columns: %w", err)
	}
	columnTypes, err := rows.ColumnTypes()
	if err != nil {
		return nil, ExecuteSQLOutput{}, fmt.Errorf("failed to get column types: %w", err)
	}
	columnInfo := make([]ColumnInfo, len(columns))
	for i, col := range columns {
		columnInfo[i] = ColumnInfo{
			Name:         col,
			DatabaseType: columnTypes[i].DatabaseTypeName(),
		}
	}

	results, limited, err := scanRowMaps(rows, columns, limit)
	if err != nil {
		return nil, ExecuteSQLOutput{}, err
	}

	output := ExecuteSQLOutput{
		Rows:     results,
		RowCount: len(results),
		Columns:  columnInfo,
		Limited:  limited,
		Query:    query,
	}
	return &mcp.CallToolResult{}, output, nil
}

// scanRowMaps reads up to limit rows from rows and converts each into a
// column-name → value map using convertValue. limited is true when at least
// one more row was available after limit rows had been collected. The
// returned slice is never nil, so it encodes as [] rather than null.
func scanRowMaps(rows *sql.Rows, columns []string, limit int) (results []map[string]interface{}, limited bool, err error) {
	results = make([]map[string]interface{}, 0)

	// Scan targets are reused across rows. Each value is converted before the
	// next Scan overwrites it, and Scan copies driver []byte into *interface{}.
	values := make([]interface{}, len(columns))
	targets := make([]interface{}, len(columns))
	for i := range values {
		targets[i] = &values[i]
	}

	for rows.Next() {
		if len(results) >= limit {
			limited = true
			break
		}
		if err := rows.Scan(targets...); err != nil {
			return nil, false, fmt.Errorf("row scan failed: %w", err)
		}
		row := make(map[string]interface{}, len(columns))
		for i, col := range columns {
			row[col] = convertValue(values[i])
		}
		results = append(results, row)
	}
	if err := rows.Err(); err != nil {
		return nil, false, fmt.Errorf("row iteration failed: %w", err)
	}
	return results, limited, nil
}
// convertValue turns a value scanned from database/sql into one that
// encoding/json can marshal in the tool response:
//   - nil stays nil;
//   - time.Time becomes an RFC 3339 string (no fractional seconds);
//   - []byte becomes a standard base64 string;
//   - strings, bools and every built-in integer and float kind pass through
//     unchanged, so they encode as JSON strings, booleans and numbers;
//   - NaN and ±Inf floats, which encoding/json refuses to marshal, become
//     their %v text ("NaN", "+Inf", "-Inf");
//   - any other type is rendered with %v.
func convertValue(val interface{}) interface{} {
	switch v := val.(type) {
	case nil:
		return nil
	case time.Time:
		return v.Format(time.RFC3339)
	case []byte:
		return base64.StdEncoding.EncodeToString(v)
	case float64:
		// v*0 is NaN for NaN and ±Inf but ±0 for every finite v, so this
		// catches exactly the values JSON cannot represent.
		if v*0 != 0 {
			return fmt.Sprintf("%v", v)
		}
		return v
	case float32:
		if v*0 != 0 {
			return fmt.Sprintf("%v", v)
		}
		return v
	case string, bool,
		int, int8, int16, int32, int64,
		uint, uint8, uint16, uint32, uint64:
		return v
	default:
		return fmt.Sprintf("%v", v)
	}
}