calls_from_preds.go
package calls
import (
"encoding/csv"
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"skraak/datafile"
"skraak/wav"
)
// Defaults for the clustering algorithm and for .data file generation.
const (
// CLUSTER_GAP_MULTIPLIER is the default gap multiplier: the gap threshold is
// CLUSTER_GAP_MULTIPLIER * clip_duration. Currently 2; 3 is the value used for kiwi.
CLUSTER_GAP_MULTIPLIER = 2
// MIN_DETECTIONS_PER_CLUSTER is the default minimum-detections setting:
// 1 = filter out single detections (used for kiwi, which have long ~30s calls),
// 0 = let single detections pass through.
MIN_DETECTIONS_PER_CLUSTER = 0
// DEFAULT_CERTAINTY is the certainty stored on every generated .data label.
DEFAULT_CERTAINTY = 70
// DOT_DATA_WORKERS is the number of parallel workers used for .data file writing.
DOT_DATA_WORKERS = 8
)
// CallsFromPredsInput defines the input for the calls-from-preds tool.
type CallsFromPredsInput struct {
// CSVPath is the predictions CSV to read; its directory is also the base
// directory handed to the WAV lookup when .data files are written.
CSVPath string `json:"csv_path"`
// Filter is the filter name; when empty it is parsed from the CSV filename.
Filter string `json:"filter"`
// WriteDotData enables writing .data files for the clustered calls.
WriteDotData bool `json:"write_dot_data"`
// GapMultiplier overrides CLUSTER_GAP_MULTIPLIER when greater than zero.
GapMultiplier int `json:"gap_multiplier"`
// MinDetections overrides MIN_DETECTIONS_PER_CLUSTER when zero or greater,
// so only a negative value selects the default.
MinDetections int `json:"min_detections"`
ProgressHandler ProgressHandler `json:"-"` // Optional progress callback (not serialized)
}
// ProgressHandler is a callback function for reporting progress during long operations.
//
//   - processed: number of items processed so far
//   - total: total number of items to process
//   - message: optional status message (may be empty)
type ProgressHandler func(processed, total int, message string)
// CallsFromPredsOutput defines the output for the calls-from-preds tool.
type CallsFromPredsOutput struct {
// Calls holds every clustered call produced from the predictions.
Calls []ClusteredCall `json:"calls"`
// TotalCalls is len(Calls).
TotalCalls int `json:"total_calls"`
// ClipDuration is end_time - start_time of the first CSV data row.
ClipDuration float64 `json:"clip_duration"`
// GapThreshold is the gap multiplier times ClipDuration.
GapThreshold float64 `json:"gap_threshold"`
// SpeciesCount is the per-species tally returned by the clustering step.
SpeciesCount map[string]int `json:"species_count"`
// DataFilesWritten counts .data files written when WriteDotData is set.
DataFilesWritten int `json:"data_files_written"`
// DataFilesSkipped counts files for which no .data file was written.
DataFilesSkipped int `json:"data_files_skipped"`
// Filter is the filter name that was actually used.
Filter string `json:"filter"`
// Error carries the failure message, if any; omitted from JSON when nil.
Error *string `json:"error,omitempty"`
}
// CallsFromPreds reads a predictions CSV and clusters detections into continuous bird calls.
//
// The filter name is input.Filter or, when that is empty, whatever
// ParseFilterFromFilename extracts from the CSV path; if both are empty an
// error is returned. input.GapMultiplier values <= 0 select
// CLUSTER_GAP_MULTIPLIER, and negative input.MinDetections values select
// MIN_DETECTIONS_PER_CLUSTER. When input.WriteDotData is set, .data files are
// written for the clustered calls.
//
// On failure the returned output still carries the fields populated so far,
// with output.Error set to the same message as the returned error. A failure
// while writing .data files wraps the underlying error, so errors.Is and
// errors.As keep working for callers.
func CallsFromPreds(input CallsFromPredsInput) (CallsFromPredsOutput, error) {
	var output CallsFromPredsOutput

	// Determine filter: use provided filter, or parse from CSV filename
	filter := input.Filter
	if filter == "" {
		filter = ParseFilterFromFilename(input.CSVPath)
	}
	if filter == "" {
		errMsg := "Filter must be specified via --filter flag or parsable from CSV filename"
		output.Error = &errMsg
		return output, fmt.Errorf("%s", errMsg)
	}
	output.Filter = filter

	_, detections, clipDuration, err := readPredCSV(input.CSVPath)
	if err != nil {
		errMsg := err.Error()
		output.Error = &errMsg
		return output, err
	}
	output.ClipDuration = clipDuration

	gapMultiplier := CLUSTER_GAP_MULTIPLIER
	if input.GapMultiplier > 0 {
		gapMultiplier = input.GapMultiplier
	}
	// Zero is a meaningful setting here, so only a negative value means "use the default".
	minDetections := MIN_DETECTIONS_PER_CLUSTER
	if input.MinDetections >= 0 {
		minDetections = input.MinDetections
	}

	gapThreshold := float64(gapMultiplier) * clipDuration
	output.GapThreshold = gapThreshold

	allCalls, speciesCount := clusterDetections(detections, clipDuration, gapThreshold, minDetections)
	output.Calls = allCalls
	output.TotalCalls = len(allCalls)
	output.SpeciesCount = speciesCount

	if input.WriteDotData {
		dataFilesWritten, dataFilesSkipped, err := writeDotFiles(input.CSVPath, filter, allCalls, input.ProgressHandler)
		// Record the counts even when writing failed: some files may already
		// have been written before the error occurred.
		output.DataFilesWritten = dataFilesWritten
		output.DataFilesSkipped = dataFilesSkipped
		if err != nil {
			// %w keeps the original error in the chain.
			wrapped := fmt.Errorf("Error writing .data files: %w", err)
			errMsg := wrapped.Error()
			output.Error = &errMsg
			return output, wrapped
		}
	}

	return output, nil
}
// readPredCSV opens the predictions CSV at csvPath, resolves its header into
// column indices and reads every data row.
//
// It returns the column mapping, the detection start times grouped by file
// and species, and the clip duration. On any failure the column mapping is
// the zero value, the map is nil and the duration is 0.
func readPredCSV(csvPath string) (predCSVColumns, map[predFileSpeciesKey][]float64, float64, error) {
	var noColumns predCSVColumns

	f, openErr := os.Open(csvPath)
	if openErr != nil {
		return noColumns, nil, 0, fmt.Errorf("failed to open CSV file: %w", openErr)
	}
	// The file is only read, so a Close error carries no information.
	defer func() { _ = f.Close() }()

	r := csv.NewReader(f)
	r.ReuseRecord = true

	headerRow, headerErr := r.Read()
	if headerErr != nil {
		return noColumns, nil, 0, fmt.Errorf("failed to read CSV header: %w", headerErr)
	}

	columns, colErr := findPredCSVColumns(headerRow)
	if colErr != nil {
		return noColumns, nil, 0, colErr
	}

	byFileSpecies, clipLen, rowErr := readPredCSVRows(r, columns)
	if rowErr != nil {
		return noColumns, nil, 0, rowErr
	}

	return columns, byFileSpecies, clipLen, nil
}
// predCSVColumns holds the column indices for a predictions CSV.
// ebirdCodes and ebirdIdx are parallel slices: ebirdIdx[i] is the column
// index of the species column whose header is ebirdCodes[i].
type predCSVColumns struct {
fileIdx int // index of the "file" column
startTimeIdx int // index of the "start_time" column
endTimeIdx int // index of the "end_time" column
ebirdCodes []string // headers of the species columns
ebirdIdx []int // column indices matching ebirdCodes
}
// findPredCSVColumns maps a predictions CSV header onto column indices.
//
// "file", "start_time" and "end_time" are required. "NotKiwi" and "0.0" are
// ignored. Every other header is recorded as an ebird code column, and at
// least one such column must be present.
func findPredCSVColumns(header []string) (predCSVColumns, error) {
	found := predCSVColumns{fileIdx: -1, startTimeIdx: -1, endTimeIdx: -1}

	for pos, name := range header {
		switch name {
		case "file":
			found.fileIdx = pos
		case "start_time":
			found.startTimeIdx = pos
		case "end_time":
			found.endTimeIdx = pos
		case "NotKiwi", "0.0":
			// Ignored columns: neither required nor a species.
		default:
			found.ebirdCodes = append(found.ebirdCodes, name)
			found.ebirdIdx = append(found.ebirdIdx, pos)
		}
	}

	// An index still at -1 means the required column never appeared.
	missingRequired := found.fileIdx == -1 || found.startTimeIdx == -1 || found.endTimeIdx == -1
	switch {
	case missingRequired:
		return found, fmt.Errorf("CSV must have 'file', 'start_time', and 'end_time' columns")
	case len(found.ebirdCodes) == 0:
		return found, fmt.Errorf("CSV must have at least one ebird code column")
	}
	return found, nil
}
// readPredCSVRows reads every remaining data row from reader and groups the
// start times of positive detections by file and species.
//
// The clip duration is end_time - start_time of the first data row; later
// rows contribute only their start_time. Input with no data rows yields an
// empty map and a clip duration of 0.
//
// A start_time or end_time that is not a valid number is returned as an
// error naming the CSV line, rather than being silently read as 0 (which
// would produce a wrong clip duration or a detection at time 0).
func readPredCSVRows(reader *csv.Reader, cols predCSVColumns) (map[predFileSpeciesKey][]float64, float64, error) {
	detections := make(map[predFileSpeciesKey][]float64)

	// parseTime parses column idx of the record most recently returned by reader.
	parseTime := func(record []string, idx int, column string) (float64, error) {
		value, err := strconv.ParseFloat(record[idx], 64)
		if err != nil {
			line, _ := reader.FieldPos(idx)
			return 0, fmt.Errorf("invalid %s %q on CSV line %d: %w", column, record[idx], line, err)
		}
		return value, nil
	}

	// The first data row is special: it also defines the clip duration.
	record, err := reader.Read()
	if err == io.EOF {
		return detections, 0, nil
	}
	if err != nil {
		return nil, 0, fmt.Errorf("failed to read first CSV row: %w", err)
	}
	startTime, err := parseTime(record, cols.startTimeIdx, "start_time")
	if err != nil {
		return nil, 0, err
	}
	endTime, err := parseTime(record, cols.endTimeIdx, "end_time")
	if err != nil {
		return nil, 0, err
	}
	clipDuration := endTime - startTime
	addDetectionsFromRow(record, cols, startTime, detections)

	for {
		record, err := reader.Read()
		if err == io.EOF {
			break
		}
		if err != nil {
			return nil, 0, fmt.Errorf("failed to read CSV row: %w", err)
		}
		startTime, err := parseTime(record, cols.startTimeIdx, "start_time")
		if err != nil {
			return nil, 0, err
		}
		addDetectionsFromRow(record, cols, startTime, detections)
	}

	return detections, clipDuration, nil
}
// addDetectionsFromRow records startTime under every species whose column in
// record holds exactly "1", keyed by the row's file name and the species code.
// Rows without any "1" leave detections untouched.
func addDetectionsFromRow(record []string, cols predCSVColumns, startTime float64, detections map[predFileSpeciesKey][]float64) {
	source := record[cols.fileIdx]
	for n, column := range cols.ebirdIdx {
		if record[column] != "1" {
			continue
		}
		k := predFileSpeciesKey{File: source, EbirdCode: cols.ebirdCodes[n]}
		detections[k] = append(detections[k], startTime)
	}
}
// writeDotFiles writes AviaNZ .data files for each audio file that has calls.
//
// Calls are grouped by the base name of their File field. Batches of fewer
// than 10 files are written sequentially and larger ones through the
// parallel worker pool. It returns the number of files written, the number
// skipped, and any write error.
func writeDotFiles(csvPath, filter string, calls []ClusteredCall, progress ProgressHandler) (int, int, error) {
	// Everything is resolved relative to the directory that holds the CSV.
	baseDir := filepath.Dir(csvPath)

	// Bucket the calls under the bare file name of their recording.
	grouped := make(map[string][]ClusteredCall)
	for i := range calls {
		name := filepath.Base(calls[i].File)
		grouped[name] = append(grouped[name], calls[i])
	}

	fileCount := len(grouped)
	if progress != nil {
		progress(0, fileCount, "Processing WAV files")
	}

	// Small batches are not worth the goroutine overhead.
	if fileCount >= 10 {
		return writeDotFilesParallel(baseDir, filter, grouped, progress)
	}
	return writeDotFilesSequential(baseDir, filter, grouped, progress)
}
// dotDataJob represents a single file to process: one audio file name
// together with all clustered calls that belong to it.
type dotDataJob struct {
filename string // base file name used to look up the WAV file
fileCalls []ClusteredCall // calls to write into that file's .data file
}
// dotDataResult represents the result of processing a single file.
// A result with written == false and err == nil means the file was skipped
// (no WAV file found, or its header could not be parsed).
type dotDataResult struct {
filename string // file name copied from the job
written bool // true when the .data file was written
err error // non-nil only when writing the .data file failed
}
// writeDotFilesSequential writes .data files one at a time (for small batches).
//
// A file is skipped when no matching WAV file is found or when its header
// cannot be parsed. The first .data write failure aborts the loop and is
// returned together with the counts accumulated up to that point.
func writeDotFilesSequential(csvDir, filter string, callsByFile map[string][]ClusteredCall, progress ProgressHandler) (int, int, error) {
	var written, skipped, done int
	fileCount := len(callsByFile)

	// tick marks one more file as handled and notifies the optional callback.
	tick := func() {
		done++
		if progress != nil {
			progress(done, fileCount, "")
		}
	}

	for name, fileCalls := range callsByFile {
		// Drop the extension so the lookup can resolve the WAV file's actual casing.
		stem := strings.TrimSuffix(name, filepath.Ext(name))
		wavPath := findWAVFile(csvDir, stem)
		if wavPath == "" {
			skipped++
			tick()
			continue
		}

		sampleRate, duration, headerErr := wav.ParseWAVHeaderMinimal(wavPath)
		if headerErr != nil {
			skipped++
			tick()
			continue
		}

		dataPath := wavPath + ".data"
		meta, segments := buildAviaNZMetaAndSegments(fileCalls, filter, duration, sampleRate)
		if writeErr := writeDotDataFileSafe(dataPath, segments, filter, meta); writeErr != nil {
			return written, skipped, fmt.Errorf("failed to write %s: %w", dataPath, writeErr)
		}

		written++
		tick()
	}

	return written, skipped, nil
}
// writeDotFilesParallel writes .data files concurrently using a pool of
// DOT_DATA_WORKERS workers.
//
// Every file is attempted even if some fail. The first error received is
// returned alongside the written and skipped counts, and a failed write
// counts as skipped.
func writeDotFilesParallel(csvDir, filter string, callsByFile map[string][]ClusteredCall, progress ProgressHandler) (int, int, error) {
	fileCount := len(callsByFile)

	// Both channels can hold one entry per file, so sends never block.
	jobQueue := make(chan dotDataJob, fileCount)
	resultQueue := make(chan dotDataResult, fileCount)

	var workers sync.WaitGroup
	workers.Add(DOT_DATA_WORKERS)
	for range DOT_DATA_WORKERS {
		go dotDataWorker(csvDir, filter, jobQueue, resultQueue, &workers)
	}

	for name, fileCalls := range callsByFile {
		jobQueue <- dotDataJob{filename: name, fileCalls: fileCalls}
	}
	close(jobQueue)

	// Close the result channel once every worker has drained the job queue,
	// which ends the collection loop below.
	go func() {
		workers.Wait()
		close(resultQueue)
	}()

	var (
		written, skipped int
		firstErr         error
		done             atomic.Int32
	)
	for res := range resultQueue {
		if firstErr == nil && res.err != nil {
			firstErr = res.err
		}

		if res.written {
			written++
		} else {
			skipped++
		}

		if progress == nil {
			continue
		}
		progress(int(done.Add(1)), fileCount, "")
	}

	return written, skipped, firstErr
}
// dotDataWorker consumes jobs until the jobs channel is closed, sending
// exactly one result per job, and signals wg when it returns.
//
// A missing WAV file or an unreadable WAV header yields a result that is
// neither written nor an error. A failed .data write yields a result
// carrying the error.
func dotDataWorker(csvDir, filter string, jobs <-chan dotDataJob, results chan<- dotDataResult, wg *sync.WaitGroup) {
	defer wg.Done()

	// handle processes one job and reports its outcome.
	handle := func(job dotDataJob) dotDataResult {
		outcome := dotDataResult{filename: job.filename}

		// Drop the extension so the lookup can resolve the WAV file's actual casing.
		stem := strings.TrimSuffix(job.filename, filepath.Ext(job.filename))
		wavPath := findWAVFile(csvDir, stem)
		if wavPath == "" {
			return outcome
		}

		sampleRate, duration, headerErr := wav.ParseWAVHeaderMinimal(wavPath)
		if headerErr != nil {
			return outcome
		}

		dataPath := wavPath + ".data"
		meta, segments := buildAviaNZMetaAndSegments(job.fileCalls, filter, duration, sampleRate)
		if writeErr := writeDotDataFileSafe(dataPath, segments, filter, meta); writeErr != nil {
			outcome.err = fmt.Errorf("failed to write %s: %w", dataPath, writeErr)
			return outcome
		}

		outcome.written = true
		return outcome
	}

	for job := range jobs {
		results <- handle(job)
	}
}
// buildAviaNZMetaAndSegments creates the metadata record and one segment per
// call for a .data file.
//
// Each segment is [start, end, freq_low, freq_high, labels] with freq_low 0
// and freq_high set to sampleRate. It carries a single label holding the
// call's ebird code, DEFAULT_CERTAINTY and filter. With no calls the
// returned segment slice is nil.
func buildAviaNZMetaAndSegments(calls []ClusteredCall, filter string, duration float64, sampleRate int) (AviaNZMeta, []AviaNZSegment) {
	reviewerName := "None"

	var segments []AviaNZSegment
	for i := range calls {
		segments = append(segments, AviaNZSegment{
			calls[i].StartTime,
			calls[i].EndTime,
			0,          // freq_low
			sampleRate, // freq_high (full band)
			[]AviaNZLabel{{
				Species:   calls[i].EbirdCode,
				Certainty: DEFAULT_CERTAINTY,
				Filter:    filter,
			}},
		})
	}

	return AviaNZMeta{
		Operator: "Auto",
		Reviewer: &reviewerName,
		Duration: duration,
	}, segments
}
// writeAviaNZDataFile writes data as compact JSON (followed by a newline) to
// a new .data file at path. It does not check for an existing file: anything
// already at path is truncated.
//
// It returns an error if the file cannot be created, the data cannot be
// encoded, or the file cannot be closed. A failed encode or close removes
// the file again, so no empty or truncated .data file is left behind to
// block later runs.
func writeAviaNZDataFile(path string, data []any) error {
	file, err := os.Create(path)
	if err != nil {
		return fmt.Errorf("failed to create file: %w", err)
	}

	encoder := json.NewEncoder(file)
	encoder.SetIndent("", "") // No indentation for compact output
	encodeErr := encoder.Encode(data)
	// Close reports write errors that only surface on flush, so it is checked
	// rather than deferred and ignored.
	closeErr := file.Close()

	if encodeErr != nil || closeErr != nil {
		// Best-effort cleanup; the original failure is what gets reported.
		_ = os.Remove(path)
		if encodeErr != nil {
			return fmt.Errorf("failed to encode JSON: %w", encodeErr)
		}
		return fmt.Errorf("failed to close file: %w", closeErr)
	}
	return nil
}
// writeDotDataFileSafe writes newSegments to the .data file at path without
// clobbering existing results:
//   - os.Stat on path fails (normally: no such file): write a new file
//   - existing file cannot be parsed: return error (refuse to clobber)
//   - existing file already has a segment labelled with filter: return error
//     (refuse to clobber)
//   - otherwise: append the new segments, sort by start time, write back
func writeDotDataFileSafe(path string, newSegments []AviaNZSegment, filter string, meta AviaNZMeta) error {
	if _, statErr := os.Stat(path); statErr != nil {
		// Nothing usable on disk: write a fresh file from meta and segments.
		return writeAviaNZDataFile(path, buildDataFileFromSegments(meta, newSegments))
	}

	current, parseErr := datafile.ParseDataFile(path)
	if parseErr != nil {
		return fmt.Errorf("cannot parse existing %s: %w (refusing to clobber)", path, parseErr)
	}

	// A segment already labelled with this filter means it was written before.
	for _, have := range current.Segments {
		if have.HasFilterLabel(filter) {
			return fmt.Errorf("%s already contains filter '%s' (refusing to clobber)", path, filter)
		}
	}

	// Different filter: merge the new segments in and keep the file time-ordered.
	for i := range newSegments {
		current.Segments = append(current.Segments, convertAviaNZSegment(newSegments[i], filter))
	}
	sort.Slice(current.Segments, func(a, b int) bool {
		return current.Segments[a].StartTime < current.Segments[b].StartTime
	})

	return current.Write(path)
}
// convertAviaNZSegment converts an AviaNZSegment to a datafile.Segment.
//
// Elements 0 and 1 must be float64 and element 4 must be []AviaNZLabel,
// otherwise the type assertion panics. Elements 2 and 3 may be int or
// float64; any other type becomes 0. Every converted label takes the filter
// argument, not the label's own Filter field.
func convertAviaNZSegment(seg AviaNZSegment, filter string) *datafile.Segment {
	// toFloat widens a frequency bound that may have been stored as int or float64.
	toFloat := func(raw any) float64 {
		switch n := raw.(type) {
		case int:
			return float64(n)
		case float64:
			return n
		default:
			return 0
		}
	}

	srcLabels := seg[4].([]AviaNZLabel)
	converted := make([]*datafile.Label, 0, len(srcLabels))
	for _, src := range srcLabels {
		converted = append(converted, &datafile.Label{
			Species:   src.Species,
			Certainty: src.Certainty,
			Filter:    filter,
		})
	}

	lowFreq := toFloat(seg[2])
	highFreq := toFloat(seg[3])

	return &datafile.Segment{
		StartTime: seg[0].(float64),
		EndTime:   seg[1].(float64),
		FreqLow:   lowFreq,
		FreqHigh:  highFreq,
		Labels:    converted,
	}
}
// buildDataFileFromSegments lays out the contents of a .data file as a flat
// list: meta first, then each segment in the order given.
func buildDataFileFromSegments(meta AviaNZMeta, segments []AviaNZSegment) []any {
	out := make([]any, 1+len(segments))
	out[0] = meta
	for i := range segments {
		out[i+1] = segments[i]
	}
	return out
}