Fork channel

Create a new channel as a copy of main.

Rename channel

Rename main to:

Delete channel

Delete main? This cannot be undone.

import_segments_prepare.go
package imp

import (
	"context"
	"database/sql"
	"fmt"
	"os"
	"path/filepath"
	"strings"

	"skraak/datafile"
	"skraak/db"
	"skraak/utils"
)

// validateAndPrepareSegments runs import phases B and C: it scans the .data
// files, checks the database state they depend on, and builds the ID lookup
// maps needed to insert segments.
//
// It returns the prepared validation state, the per-file errors gathered while
// scanning and validating files, and a fatal error. When no file scans
// successfully it returns (nil, scanErrors, nil); the remaining checks are
// skipped. On a fatal error the validation state is nil and only the scan
// errors collected so far are returned.
func validateAndPrepareSegments(
	q db.Querier,
	input ImportSegmentsInput,
	mapping MappingFile,
	dataFiles []string,
) (*segmentValidation, []ImportSegmentError, error) {
	// Phase B: scan every .data file and gather the distinct filters,
	// species and calltypes they reference.
	scanned, scanErrs, filters, species, calltypes := scanAllDataFiles(dataFiles, input.Folder)
	if len(scanned) == 0 {
		return nil, scanErrs, nil
	}

	// The dataset/location/cluster chain must be consistent.
	if err := validateSegmentHierarchy(q, input.DatasetID, input.LocationID, input.ClusterID); err != nil {
		return nil, scanErrs, err
	}

	// Every referenced filter must be present in the database.
	filterIDs, err := validateFiltersExist(q, filters)
	if err != nil {
		return nil, scanErrs, fmt.Errorf("filter validation failed: %w", err)
	}

	// The mapping must cover every species/calltype seen in the data files.
	report, err := ValidateMappingAgainstDB(q, mapping, species, calltypes)
	switch {
	case err != nil:
		return nil, scanErrs, fmt.Errorf("mapping validation failed: %w", err)
	case report.HasErrors():
		return nil, scanErrs, fmt.Errorf("mapping validation failed: %s", report.Error())
	}

	// Resolve species and calltype labels to database IDs.
	speciesIDs, calltypeIDs, err := loadSpeciesCalltypeIDs(q, mapping, species, calltypes)
	if err != nil {
		return nil, scanErrs, fmt.Errorf("failed to load species/calltype IDs: %w", err)
	}

	// Per-file checks: known hash, linked to the dataset, no labels yet.
	// Failures here are reported per file rather than aborting the import.
	fileIDs, fileErrs := validateAndMapFiles(q, scanned, input.ClusterID, input.DatasetID)

	prepared := &segmentValidation{
		scannedFiles:  scanned,
		filterIDMap:   filterIDs,
		speciesIDMap:  speciesIDs,
		calltypeIDMap: calltypeIDs,
		fileIDMap:     fileIDs,
	}
	return prepared, append(scanErrs, fileErrs...), nil
}

// validateSegmentImportInput checks the import parameters before any work is
// done: the folder must be stat-able and a directory, the mapping file must be
// stat-able, and the dataset, location and cluster IDs must each pass
// utils.ValidateShortID. The first failing check is returned.
func validateSegmentImportInput(input ImportSegmentsInput) error {
	// The folder must exist and be a directory.
	folderInfo, statErr := os.Stat(input.Folder)
	if statErr != nil {
		return fmt.Errorf("folder does not exist: %s", input.Folder)
	}
	if !folderInfo.IsDir() {
		return fmt.Errorf("path is not a folder: %s", input.Folder)
	}

	// The mapping file must exist.
	if _, statErr := os.Stat(input.Mapping); statErr != nil {
		return fmt.Errorf("mapping file does not exist: %s", input.Mapping)
	}

	// Check the IDs in a fixed order: dataset, location, cluster.
	idChecks := []struct {
		value string
		field string
	}{
		{input.DatasetID, "dataset_id"},
		{input.LocationID, "location_id"},
		{input.ClusterID, "cluster_id"},
	}
	for _, check := range idChecks {
		if err := utils.ValidateShortID(check.value, check.field); err != nil {
			return err
		}
	}

	return nil
}

// validateSegmentHierarchy runs three database checks in order and returns the
// first error encountered: the dataset type check for imports, the
// location-belongs-to-dataset check, and the cluster-belongs-to-location
// check. Later checks are skipped once one has failed.
func validateSegmentHierarchy(q db.Querier, datasetID, locationID, clusterID string) error {
	err := db.ValidateDatasetTypeForImport(q, datasetID)
	if err == nil {
		err = db.ValidateLocationBelongsToDataset(q, locationID, datasetID)
	}
	if err == nil {
		err = db.ClusterBelongsToLocation(q, clusterID, locationID)
	}
	return err
}

// scanAllDataFiles parses each .data file and gathers the distinct label
// values found across all of them.
//
// For every path the matching WAV path is derived by stripping the ".data"
// suffix; a file whose WAV cannot be stat-ed, or which fails to parse, is
// recorded as a validation error and skipped. The folder parameter is not
// used by this function.
//
// Returns, in order: the successfully scanned files, the per-file errors, the
// set of filter names, the set of species names, and the calltypes seen per
// species (species -> calltype -> true). Labels with an empty calltype add
// nothing to the last map.
func scanAllDataFiles(dataFiles []string, folder string) (
	[]scannedDataFile,
	[]ImportSegmentError,
	map[string]bool,
	map[string]bool,
	map[string]map[string]bool,
) {
	var (
		parsed   []scannedDataFile
		failures []ImportSegmentError
	)
	filters := map[string]bool{}
	species := map[string]bool{}
	calltypes := map[string]map[string]bool{} // species -> calltype -> true

	// reject records a validation failure against the given .data file.
	reject := func(dataPath, message string) {
		failures = append(failures, ImportSegmentError{
			File:    filepath.Base(dataPath),
			Stage:   StageValidation,
			Message: message,
		})
	}

	for _, dataPath := range dataFiles {
		// The recording sits next to the .data file, minus the suffix.
		wavPath := strings.TrimSuffix(dataPath, ".data")
		if _, statErr := os.Stat(wavPath); statErr != nil {
			reject(dataPath, fmt.Sprintf("corresponding WAV file not found: %s", filepath.Base(wavPath)))
			continue
		}

		df, parseErr := datafile.ParseDataFile(dataPath)
		if parseErr != nil {
			reject(dataPath, fmt.Sprintf("failed to parse .data file: %v", parseErr))
			continue
		}

		// Fold this file's labels into the running sets.
		for _, seg := range df.Segments {
			for _, label := range seg.Labels {
				filters[label.Filter] = true
				species[label.Species] = true
				if label.CallType == "" {
					continue
				}
				perSpecies, seen := calltypes[label.Species]
				if !seen {
					perSpecies = make(map[string]bool)
					calltypes[label.Species] = perSpecies
				}
				perSpecies[label.CallType] = true
			}
		}

		parsed = append(parsed, scannedDataFile{
			DataPath: dataPath,
			WavPath:  wavPath,
			Duration: df.Meta.Duration,
			Segments: df.Segments,
		})
	}

	return parsed, failures, filters, species, calltypes
}

// validateFiltersExist looks up every name in filterNames among the active
// rows of the filter table and returns a map from filter name to filter ID.
//
// An empty input yields an empty map and no query. An error is returned when
// the query fails, when a row cannot be scanned or iteration ends with an
// error, or when one or more names have no active filter row; in the last
// case the error lists the missing names.
func validateFiltersExist(q db.Querier, filterNames map[string]bool) (map[string]string, error) {
	filterIDMap := make(map[string]string, len(filterNames))

	if len(filterNames) == 0 {
		return filterIDMap, nil
	}

	// One bound argument per filter name.
	args := make([]any, 0, len(filterNames))
	for name := range filterNames {
		args = append(args, name)
	}

	query := `SELECT id, name FROM filter WHERE name IN (` + db.Placeholders(len(args)) + `) AND active = true`

	rows, err := q.QueryContext(context.Background(), query, args...)
	if err != nil {
		return nil, fmt.Errorf("failed to query filters: %w", err)
	}
	defer rows.Close()

	for rows.Next() {
		var id, name string
		// A row that cannot be scanned must not be dropped silently: the
		// filter would later be misreported as missing from the database.
		if err := rows.Scan(&id, &name); err != nil {
			return nil, fmt.Errorf("failed to scan filter row: %w", err)
		}
		filterIDMap[name] = id
	}
	// Next returning false can mean the iteration failed part-way through,
	// not only that the result set is exhausted.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to read filter rows: %w", err)
	}

	// Any requested name without a row is a validation failure.
	var missing []string
	for name := range filterNames {
		if _, exists := filterIDMap[name]; !exists {
			missing = append(missing, name)
		}
	}

	if len(missing) > 0 {
		return nil, fmt.Errorf("filters not found in database: [%s]", strings.Join(missing, ", "))
	}

	return filterIDMap, nil
}

// loadSpeciesCalltypeIDs builds both ID lookup maps by calling loadSpeciesIDs
// and then loadCalltypeIDs. If either call fails, both maps are returned as
// nil together with that error.
func loadSpeciesCalltypeIDs(
	q db.Querier,
	mapping MappingFile,
	uniqueSpecies map[string]bool,
	uniqueCalltypes map[string]map[string]bool,
) (map[string]string, map[string]map[string]string, error) {
	speciesIDs, speciesErr := loadSpeciesIDs(q, mapping, uniqueSpecies)
	if speciesErr != nil {
		return nil, nil, speciesErr
	}

	calltypeIDs, calltypeErr := loadCalltypeIDs(q, mapping, uniqueCalltypes)
	if calltypeErr != nil {
		return nil, nil, calltypeErr
	}

	return speciesIDs, calltypeIDs, nil
}

// loadSpeciesIDs resolves the species named in the data files to database IDs.
//
// Each data-file species is translated through mapping.GetDBSpecies; names
// without a mapping are skipped. The resulting labels are looked up among the
// active rows of the species table. The returned map is keyed by the DB
// species label (not the data-file name) and holds the species ID. Labels with
// no matching active row are simply absent from the map.
//
// An error is returned when the query fails, a row cannot be scanned, or row
// iteration ends with an error.
func loadSpeciesIDs(q db.Querier, mapping MappingFile, uniqueSpecies map[string]bool) (map[string]string, error) {
	speciesIDMap := make(map[string]string)

	// Several data-file species may map to the same DB label; deduplicate.
	dbSpeciesSet := make(map[string]bool)
	for dataSpecies := range uniqueSpecies {
		if dbSpecies, ok := mapping.GetDBSpecies(dataSpecies); ok {
			dbSpeciesSet[dbSpecies] = true
		}
	}

	if len(dbSpeciesSet) == 0 {
		return speciesIDMap, nil
	}

	// One bound argument per distinct DB label.
	args := make([]any, 0, len(dbSpeciesSet))
	for s := range dbSpeciesSet {
		args = append(args, s)
	}

	query := `SELECT id, label FROM species WHERE label IN (` + db.Placeholders(len(args)) + `) AND active = true`

	rows, err := q.QueryContext(context.Background(), query, args...)
	if err != nil {
		return nil, fmt.Errorf("failed to query species: %w", err)
	}
	defer rows.Close()

	for rows.Next() {
		var id, label string
		// Surface scan failures instead of returning a silently incomplete map.
		if err := rows.Scan(&id, &label); err != nil {
			return nil, fmt.Errorf("failed to scan species row: %w", err)
		}
		speciesIDMap[label] = id
	}
	// Next returning false can also mean iteration failed part-way through.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to read species rows: %w", err)
	}

	return speciesIDMap, nil
}

// loadCalltypeIDs resolves the calltypes named in the data files to database
// IDs.
//
// For each data-file species that has a mapping (mapping.GetDBSpecies), every
// calltype seen for it is translated with mapping.GetDBCalltype and looked up
// among active call_type rows joined to the species by label. The result is
// keyed DB species label -> DB calltype label -> calltype ID. A mapped species
// always gets an entry, possibly empty; a calltype with no matching active row
// is left out of the map.
//
// An error is returned when a lookup fails for any reason other than the row
// not existing.
func loadCalltypeIDs(q db.Querier, mapping MappingFile, uniqueCalltypes map[string]map[string]bool) (map[string]map[string]string, error) {
	calltypeIDMap := make(map[string]map[string]string)

	for dataSpecies, ctSet := range uniqueCalltypes {
		dbSpecies, ok := mapping.GetDBSpecies(dataSpecies)
		if !ok {
			continue
		}

		if calltypeIDMap[dbSpecies] == nil {
			calltypeIDMap[dbSpecies] = make(map[string]string)
		}

		for dataCalltype := range ctSet {
			dbCalltype := mapping.GetDBCalltype(dataSpecies, dataCalltype)

			var calltypeID string
			err := q.QueryRowContext(context.Background(), `
				SELECT ct.id
				FROM call_type ct
				JOIN species s ON ct.species_id = s.id
				WHERE s.label = ? AND ct.label = ? AND ct.active = true
			`, dbSpecies, dbCalltype).Scan(&calltypeID)

			switch err {
			case nil:
				calltypeIDMap[dbSpecies][dbCalltype] = calltypeID
			case sql.ErrNoRows:
				// No active calltype with this label: leave it unmapped, as
				// this function always has.
				// NOTE(review): presumably ValidateMappingAgainstDB reports
				// such gaps earlier — confirm.
			default:
				// A real database failure must not be mistaken for "not found".
				return nil, fmt.Errorf("failed to query calltype %q for species %q: %w", dbCalltype, dbSpecies, err)
			}
		}
	}

	return calltypeIDMap, nil
}

// validateAndMapFiles checks every scanned file against the database and
// returns the files that passed, keyed by file ID, plus one error per file
// that failed.
//
// A file passes when its WAV hash can be computed, an active file row with
// that hash exists in the given cluster, that row is linked to the dataset in
// file_dataset, and no active labels exist on its segments. Accepted entries
// carry the computed hash, the file ID and the duration stored in the
// database (which replaces the duration read from the .data file).
func validateAndMapFiles(
	q db.Querier,
	scannedFiles []scannedDataFile,
	clusterID string,
	datasetID string,
) (map[string]scannedDataFile, []ImportSegmentError) {
	var problems []ImportSegmentError
	accepted := make(map[string]scannedDataFile)
	ctx := context.Background()

	for _, entry := range scannedFiles {
		wavName := filepath.Base(entry.WavPath)

		// invalid records a validation-stage failure for the current file.
		invalid := func(message string) {
			problems = append(problems, ImportSegmentError{
				File:    wavName,
				Stage:   StageValidation,
				Message: message,
			})
		}

		// Hash the recording; this is the key used to find it in the DB.
		sum, hashErr := utils.ComputeXXH64(entry.WavPath)
		if hashErr != nil {
			problems = append(problems, ImportSegmentError{
				File:    wavName,
				Stage:   StageHash,
				Message: fmt.Sprintf("failed to compute hash: %v", hashErr),
			})
			continue
		}
		entry.WavHash = sum

		// Locate the file row by hash within the cluster.
		var (
			fileID   string
			duration float64
		)
		lookupErr := q.QueryRowContext(ctx, `
			SELECT id, duration FROM file WHERE xxh64_hash = ? AND cluster_id = ? AND active = true
		`, sum, clusterID).Scan(&fileID, &duration)
		switch {
		case lookupErr == sql.ErrNoRows:
			invalid(fmt.Sprintf("file hash not found in database for cluster (hash: %s)", sum))
			continue
		case lookupErr != nil:
			invalid(fmt.Sprintf("failed to query file: %v", lookupErr))
			continue
		}

		entry.FileID = fileID
		entry.Duration = duration

		// The file must be attached to the dataset through file_dataset.
		var linked bool
		linkErr := q.QueryRowContext(ctx, `
			SELECT EXISTS(SELECT 1 FROM file_dataset WHERE file_id = ? AND dataset_id = ?)
		`, fileID, datasetID).Scan(&linked)
		if linkErr != nil {
			invalid(fmt.Sprintf("failed to verify file-dataset link: %v", linkErr))
			continue
		}
		if !linked {
			invalid(fmt.Sprintf("file exists in cluster but is not linked to dataset %s", datasetID))
			continue
		}

		// Only fresh imports are allowed: reject files that already carry
		// active labels.
		var existingLabels int
		countErr := q.QueryRowContext(ctx, `
			SELECT COUNT(*) FROM label l
			JOIN segment s ON l.segment_id = s.id
			WHERE s.file_id = ? AND l.active = true
		`, fileID).Scan(&existingLabels)
		if countErr != nil {
			invalid(fmt.Sprintf("failed to check existing labels: %v", countErr))
			continue
		}
		if existingLabels > 0 {
			invalid(fmt.Sprintf("file already has %d label(s) - fresh imports only", existingLabels))
			continue
		}

		accepted[fileID] = entry
	}

	return accepted, problems
}

// countTotalSegments returns the sum of the segment counts of every file in
// fileIDMap. A nil or empty map yields 0.
func countTotalSegments(fileIDMap map[string]scannedDataFile) int {
	var total int
	for _, file := range fileIDMap {
		total += len(file.Segments)
	}
	return total
}