package tools
import (
"encoding/csv"
"fmt"
"io"
"os"
"sort"
"strconv"
)
// Constants for clustering algorithm
const (
	// CLUSTER_GAP_MULTIPLIER scales the clip duration to obtain the maximum
	// allowed gap between consecutive detection start times within one call:
	// gap threshold = CLUSTER_GAP_MULTIPLIER * clip_duration.
	// NOTE(review): Go convention is MixedCaps (ClusterGapMultiplier), but the
	// names are exported and may be referenced by external callers, so they
	// are left unchanged.
	CLUSTER_GAP_MULTIPLIER = 3 // Gap threshold = CLUSTER_GAP_MULTIPLIER * clip_duration
	// MIN_DETECTIONS_PER_CLUSTER is an inclusive lower cutoff: clusters with
	// this many detections or fewer are discarded (the comparison is <=, so
	// the default of 1 drops single-detection clusters).
	MIN_DETECTIONS_PER_CLUSTER = 1 // Minimum detections per cluster (1 = filter single detections)
)
// ClusteredCall represents a clustered bird call detection: a run of
// consecutive positive detections for one species in one file, merged into a
// single call spanning [StartTime, EndTime].
type ClusteredCall struct {
	File       string  `json:"file"`       // source audio file the detections came from
	StartTime  float64 `json:"start_time"` // start time of the first detection in the cluster (same units as the CSV, presumably seconds — TODO confirm)
	EndTime    float64 `json:"end_time"`   // start time of the last detection plus one clip duration
	EbirdCode  string  `json:"ebird_code"` // species code, taken from the CSV column header
	Detections int     `json:"detections"` // number of individual detections merged into this call
}
// CallsFromPredsInput defines the input for the calls-from-preds tool.
type CallsFromPredsInput struct {
	// CSVPath is the path to the predictions CSV file. The file must contain
	// 'file', 'start_time' and 'end_time' columns; every other column is
	// treated as an ebird species code.
	CSVPath string `json:"csv_path" jsonschema:"required,Path to predictions CSV file"`
}
// CallsFromPredsOutput defines the output for the calls-from-preds tool.
type CallsFromPredsOutput struct {
	Calls         []ClusteredCall `json:"calls"`          // clustered calls, sorted by file then start time
	TotalCalls    int             `json:"total_calls"`    // number of calls after filtering small clusters
	TotalClusters int             `json:"total_clusters"` // currently always equal to TotalCalls (both are set to len(Calls))
	ClipDuration  float64         `json:"clip_duration"`  // end_time - start_time of the first data row
	GapThreshold  float64         `json:"gap_threshold"`  // CLUSTER_GAP_MULTIPLIER * ClipDuration
	SpeciesCount  map[string]int  `json:"species_count"`  // ebird code -> number of calls for that species
	FilesCount    int             `json:"files_count"`    // number of distinct files seen in the CSV
	Error         *string         `json:"error,omitempty"` // human-readable failure message, nil on success
}
// CallsFromPreds reads a predictions CSV and clusters detections into continuous bird calls.
//
// The CSV must have 'file', 'start_time' and 'end_time' columns; every other
// column is treated as an ebird species code whose cells contain "1" for a
// positive detection. Detections of the same species in the same file are
// merged into one call when the gap between consecutive start times is at most
// CLUSTER_GAP_MULTIPLIER times the clip duration (derived from the first data
// row). Clusters with MIN_DETECTIONS_PER_CLUSTER or fewer detections are
// discarded. On failure the message is both stored in output.Error and
// returned as the error.
func CallsFromPreds(input CallsFromPredsInput) (CallsFromPredsOutput, error) {
	var output CallsFromPredsOutput

	// fail records the message on the output and returns it as the error so
	// every exit path reports the problem the same way.
	fail := func(format string, args ...any) (CallsFromPredsOutput, error) {
		errMsg := fmt.Sprintf(format, args...)
		output.Error = &errMsg
		return output, fmt.Errorf("%s", errMsg)
	}

	// Open CSV file
	file, err := os.Open(input.CSVPath)
	if err != nil {
		return fail("Failed to open CSV file: %v", err)
	}
	defer file.Close()

	reader := csv.NewReader(file)
	reader.ReuseRecord = true // Memory optimization for large files

	header, err := reader.Read()
	if err != nil {
		return fail("Failed to read CSV header: %v", err)
	}

	// Locate the fixed columns; all remaining columns are ebird codes.
	fileIdx, startTimeIdx, endTimeIdx := -1, -1, -1
	var ebirdCodes []string
	var ebirdIdx []int
	for i, col := range header {
		switch col {
		case "file":
			fileIdx = i
		case "start_time":
			startTimeIdx = i
		case "end_time":
			endTimeIdx = i
		default:
			ebirdCodes = append(ebirdCodes, col)
			ebirdIdx = append(ebirdIdx, i)
		}
	}
	if fileIdx == -1 || startTimeIdx == -1 || endTimeIdx == -1 {
		return fail("CSV must have 'file', 'start_time', and 'end_time' columns")
	}
	if len(ebirdCodes) == 0 {
		return fail("CSV must have at least one ebird code column")
	}

	// Group detection start times by (file, ebird_code) for efficient clustering.
	type FileEbirdKey struct {
		File      string
		EbirdCode string
	}
	detections := make(map[FileEbirdKey][]float64)
	clipDuration := 0.0
	filesSeen := make(map[string]bool)

	// Single pass over all data rows. The first row additionally supplies the
	// clip duration (end_time - start_time), from which the gap threshold is
	// derived. Unlike the previous version, malformed time values are reported
	// instead of being silently parsed as 0.
	first := true
	for {
		record, err := reader.Read()
		if err == io.EOF {
			break
		}
		if err != nil {
			return fail("Failed to read CSV row: %v", err)
		}
		startTime, err := strconv.ParseFloat(record[startTimeIdx], 64)
		if err != nil {
			return fail("Invalid start_time %q: %v", record[startTimeIdx], err)
		}
		if first {
			first = false
			endTime, err := strconv.ParseFloat(record[endTimeIdx], 64)
			if err != nil {
				return fail("Invalid end_time %q: %v", record[endTimeIdx], err)
			}
			clipDuration = endTime - startTime
			output.ClipDuration = clipDuration
		}
		fileName := record[fileIdx]
		filesSeen[fileName] = true
		for i, idx := range ebirdIdx {
			if record[idx] == "1" {
				key := FileEbirdKey{File: fileName, EbirdCode: ebirdCodes[i]}
				detections[key] = append(detections[key], startTime)
			}
		}
	}

	gapThreshold := float64(CLUSTER_GAP_MULTIPLIER) * clipDuration
	output.GapThreshold = gapThreshold
	output.FilesCount = len(filesSeen)

	// Cluster detections per (file, ebird_code) and convert clusters to calls.
	var allCalls []ClusteredCall
	speciesCount := make(map[string]int)
	for key, startTimes := range detections {
		sort.Float64s(startTimes)
		for _, cluster := range clusterStartTimes(startTimes, gapThreshold) {
			// Drop clusters at or below the minimum size (single detections
			// by default).
			if len(cluster) <= MIN_DETECTIONS_PER_CLUSTER {
				continue
			}
			allCalls = append(allCalls, ClusteredCall{
				File:       key.File,
				StartTime:  cluster[0],
				EndTime:    cluster[len(cluster)-1] + clipDuration,
				EbirdCode:  key.EbirdCode,
				Detections: len(cluster),
			})
			speciesCount[key.EbirdCode]++
		}
	}

	// Map iteration order is random; sort for deterministic output
	// (by file, then start time).
	sort.Slice(allCalls, func(i, j int) bool {
		if allCalls[i].File != allCalls[j].File {
			return allCalls[i].File < allCalls[j].File
		}
		return allCalls[i].StartTime < allCalls[j].StartTime
	})

	output.Calls = allCalls
	output.TotalCalls = len(allCalls)
	output.TotalClusters = len(allCalls)
	output.SpeciesCount = speciesCount
	return output, nil
}
// clusterStartTimes partitions a sorted list of start times into groups of
// consecutive detections: a new group begins whenever two neighbouring times
// are more than gapThreshold apart (a gap exactly equal to the threshold
// stays in the same group). An empty input yields a nil result.
func clusterStartTimes(startTimes []float64, gapThreshold float64) [][]float64 {
	if len(startTimes) == 0 {
		return nil
	}
	var groups [][]float64
	group := []float64{startTimes[0]}
	for _, t := range startTimes[1:] {
		if t-group[len(group)-1] > gapThreshold {
			// Too far from the previous detection: close out this group
			// and open a fresh one.
			groups = append(groups, group)
			group = []float64{t}
		} else {
			group = append(group, t)
		}
	}
	// The open group always holds at least one element; append it last.
	return append(groups, group)
}