package spectrogram

import (
	"fmt"
	"image"
	"math"
	"os"
	"path/filepath"
	"strings"
	"sync"

	"github.com/madelynnblue/go-dsp/window"
	"skraak/audio"
	"skraak/wav"
)

// hannCache memoizes Hann windows keyed by window length, so each length is
// computed at most once. hannCacheMu guards every access to the map.
var (
	hannCache   = make(map[int][]float64)
	hannCacheMu sync.RWMutex
)

// getCachedHannWindow returns the Hann window of the given size, computing and
// storing it on first use. The returned slice is the cached instance shared by
// every caller, so it must be treated as read-only.
func getCachedHannWindow(size int) []float64 {
	// Fast path: an existing entry only needs the read lock.
	hannCacheMu.RLock()
	cached, found := hannCache[size]
	hannCacheMu.RUnlock()
	if found {
		return cached
	}

	// Slow path: look again under the write lock, because another goroutine
	// may have stored the entry after the read lock was released.
	hannCacheMu.Lock()
	defer hannCacheMu.Unlock()
	if cached, found = hannCache[size]; !found {
		cached = window.Hann(size)
		hannCache[size] = cached
	}
	return cached
}

// SpectrogramConfig holds the STFT parameters consumed by GenerateSpectrogram.
type SpectrogramConfig struct {
	WindowSize int // FFT window size in samples (e.g., 400); also the length of each analysis frame and of the Hann window applied to it
	HopSize    int // Hop between the starts of consecutive windows, in samples (e.g., 200 for 50% overlap of a 400-sample window)
	SampleRate int // Sample rate in Hz
}

// DefaultSpectrogramConfig returns the package's default STFT settings for
// audio at the given sample rate: a 512-sample window advanced by half a
// window (256 samples, i.e. 50% overlap). These defaults are documented as
// matching the Julia implementation; that cannot be confirmed from this file.
func DefaultSpectrogramConfig(sampleRate int) SpectrogramConfig {
	const defaultWindow = 512

	cfg := SpectrogramConfig{SampleRate: sampleRate}
	cfg.WindowSize = defaultWindow
	cfg.HopSize = defaultWindow / 2 // 50% overlap
	return cfg
}

// GenerateSpectrogram computes a short-time Fourier transform of samples and
// returns it as an 8-bit matrix.
//
// Each frame of cfg.WindowSize samples is multiplied by a Hann window and
// passed to audio.PowerSpectrumFFT; consecutive frames start cfg.HopSize
// samples apart. Trailing samples that do not fill a whole window are ignored.
// The power values are then converted to dB and scaled to 0-255 by
// normalizeFlat.
//
// The result is indexed [row][column]:
//   - rows are frequency bins (cfg.WindowSize/2 + 1 of them)
//   - columns are time frames
//
// normalizeFlat reverses the row order, so row 0 holds the last bin produced by
// audio.PowerSpectrumFFT.
//
// It returns nil when cfg.WindowSize or cfg.HopSize is not positive, or when
// samples is shorter than one window.
func GenerateSpectrogram(samples []float64, cfg SpectrogramConfig) [][]uint8 {
	// A zero hop would divide by zero in the frame count below, a negative hop
	// would produce negative sample offsets, and a non-positive window cannot
	// be allocated or transformed. Reject them the same way as too-short input.
	if cfg.WindowSize <= 0 || cfg.HopSize <= 0 {
		return nil
	}
	if len(samples) < cfg.WindowSize {
		return nil
	}

	// Get cached Hann window
	hannWindow := getCachedHannWindow(cfg.WindowSize)

	// Number of whole windows that fit; the guards above make this >= 1.
	numFrames := (len(samples)-cfg.WindowSize)/cfg.HopSize + 1

	// Number of frequency bins (half of FFT due to symmetry)
	numFreqBins := cfg.WindowSize/2 + 1

	// Allocate power spectrum as flat backing slice (single allocation)
	powerFlat := make([]float64, numFreqBins*numFrames)

	// Scratch buffers are allocated once and reused for every frame.
	frameData := make([]float64, cfg.WindowSize)
	scratch := make([]complex128, cfg.WindowSize)
	framePower := make([]float64, numFreqBins)

	// Perform STFT
	for frame := range numFrames {
		start := frame * cfg.HopSize

		// Extract and window the frame
		for i := 0; i < cfg.WindowSize; i++ {
			frameData[i] = samples[start+i] * hannWindow[i]
		}

		// Compute the frame's power spectrum into framePower.
		audio.PowerSpectrumFFT(frameData, framePower, scratch)

		// Copy power into flat matrix (freq bins x time frames layout)
		for bin := range numFreqBins {
			powerFlat[bin*numFrames+frame] = framePower[bin]
		}
	}

	// Convert to dB and scale to uint8. This takes three passes over the data:
	// two inside convertToDB and one in normalizeFlat.
	return normalizeFlat(powerFlat, numFreqBins, numFrames)
}

// convertToDB rewrites every element of power in place as 10*log10(value) and
// returns the smallest and largest dB value produced.
//
// Zero and negative elements have no logarithm, so they are replaced by the
// smallest strictly positive element before conversion, or by 1e-20 when no
// positive element below math.MaxFloat64 is found.
func convertToDB(power []float64) (minDB, maxDB float64) {
	// Pass 1: find the substitute used for non-positive values.
	floor := math.MaxFloat64
	for _, p := range power {
		if p > 0 && p < floor {
			floor = p
		}
	}
	if floor == math.MaxFloat64 {
		floor = 1e-20
	}

	// Pass 2: convert in place and track the extremes.
	minDB, maxDB = math.MaxFloat64, -math.MaxFloat64
	for i := range power {
		p := power[i]
		if p <= 0 {
			p = floor
		}
		db := 10.0 * math.Log10(p)
		power[i] = db
		if db < minDB {
			minDB = db
		}
		if db > maxDB {
			maxDB = db
		}
	}
	return minDB, maxDB
}

// normalizeFlat converts a rows x cols matrix of power values to dB and scales
// the result linearly to 0-255.
//
// power is laid out row by row as [row0_col0, row0_col1, ..., row1_col0, ...]
// and is overwritten with dB values as a side effect. The returned matrix has
// its row order reversed: output row i is input row rows-1-i. All output rows
// are views into one shared backing array. It returns nil when rows or cols is
// zero.
func normalizeFlat(power []float64, rows, cols int) [][]uint8 {
	if rows == 0 || cols == 0 {
		return nil
	}

	minDB, maxDB := convertToDB(power)

	// A constant input has zero range; use 1 so the division stays defined
	// and every pixel maps to 0.
	span := maxDB - minDB
	if span == 0 {
		span = 1
	}
	scale := 255.0 / span

	pixels := make([]uint8, rows*cols)
	out := make([][]uint8, rows)
	for dstRow := range out {
		dst := pixels[dstRow*cols : (dstRow+1)*cols]
		// Read the mirrored source row to apply the vertical flip.
		src := power[(rows-1-dstRow)*cols:]
		for c := range dst {
			dst[c] = uint8((src[c] - minDB) * scale)
		}
		out[dstRow] = dst
	}

	return out
}

// ExtractSegmentSamples returns the samples that fall in the time range
// [startSec, endSec), given in seconds at the supplied sample rate.
//
// Both bounds are converted to sample indices by truncation. The start is
// clamped to 0 and the end to len(samples). It returns nil when the clamped
// range is empty. The result is a sub-slice of samples, not a copy.
func ExtractSegmentSamples(samples []float64, sampleRate int, startSec, endSec float64) []float64 {
	rate := float64(sampleRate)
	from, to := int(startSec*rate), int(endSec*rate)

	if from < 0 {
		from = 0
	}
	if to > len(samples) {
		to = len(samples)
	}
	if to <= from {
		return nil
	}
	return samples[from:to]
}

// SpectrogramImageFromSamples generates a spectrogram image from audio samples.
// This is the core pipeline: spectrogram -> colormap/grayscale -> image -> resize.
// Use this when you already have samples (e.g., after bandpass filtering).
func SpectrogramImageFromSamples(samples []float64, sampleRate int, color bool, imgSize int) image.Image {
	if len(samples) == 0 {
		return nil
	}

	config := DefaultSpectrogramConfig(sampleRate)
	spectrogram := GenerateSpectrogram(samples, config)
	if spectrogram == nil {
		return nil
	}

	var img image.Image
	if color {
		colorData := ApplyL4Colormap(spectrogram)
		img = CreateRGBImage(colorData)
	} else {
		img = CreateGrayscaleImage(spectrogram)
	}
	if img == nil {
		return nil
	}

	imgSize = ClampImageSize(imgSize)
	return ResizeImage(img, imgSize, imgSize)
}

// GenerateSegmentSpectrogram builds a spectrogram image for the time range
// [startTime, endTime] of the WAV file that belongs to dataFilePath.
//
// The WAV path is dataFilePath with a trailing ".data" removed. Only the
// requested segment is read, via wav.ReadWAVSegmentSamples. If the file's
// sample rate is above audio.DefaultMaxSampleRate, the samples are resampled
// down to that rate first. color and imgSize are forwarded to
// SpectrogramImageFromSamples (color=true selects the L4 colormap, false
// grayscale; imgSize is clamped there by ClampImageSize, documented as
// [224, 896], though that range is not visible in this file).
//
// A read failure is returned unchanged. The image is nil with a nil error when
// the segment holds no samples or when SpectrogramImageFromSamples produces no
// image, so callers must check the image as well as the error.
func GenerateSegmentSpectrogram(dataFilePath string, startTime, endTime float64, color bool, imgSize int) (image.Image, error) {
	audioPath := strings.TrimSuffix(dataFilePath, ".data")

	pcm, rate, err := wav.ReadWAVSegmentSamples(audioPath, startTime, endTime)
	switch {
	case err != nil:
		return nil, err
	case len(pcm) == 0:
		return nil, nil
	}

	// Cap the sample rate used for the spectrogram.
	if rate > audio.DefaultMaxSampleRate {
		pcm = audio.ResampleRate(pcm, rate, audio.DefaultMaxSampleRate)
		rate = audio.DefaultMaxSampleRate
	}

	return SpectrogramImageFromSamples(pcm, rate, color, imgSize), nil
}

// WritePNGFile encodes img as PNG into a new file at path.
//
// The file is opened with O_CREATE|O_EXCL, so the call fails atomically with a
// "file already exists" error if path is already taken and never overwrites an
// existing file. If encoding or closing fails, the file created by this call
// is removed again: leaving the partial file behind would both expose a
// corrupt PNG and make every retry fail with "file already exists".
//
// Errors from creating, writing, and closing wrap the underlying error.
func WritePNGFile(path string, img image.Image) error {
	file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0644)
	if err != nil {
		if os.IsExist(err) {
			return fmt.Errorf("file already exists: %s", path)
		}
		return fmt.Errorf("failed to create PNG: %w", err)
	}
	if err := WritePNG(img, file); err != nil {
		// Best-effort cleanup; the write error is the one worth reporting.
		// Close before removing so the delete also works on platforms that
		// refuse to remove open files.
		_ = file.Close()
		_ = os.Remove(path)
		return fmt.Errorf("failed to write PNG: %w", err)
	}
	if err := file.Close(); err != nil {
		// A failed close can mean buffered data never reached the disk, so
		// the file cannot be trusted. Cleanup is best-effort.
		_ = os.Remove(path)
		return fmt.Errorf("failed to close PNG: %w", err)
	}
	return nil
}

// ClipBaseName builds the extension-less file name for a clip:
//
//	prefix_basename_startTime_endTime
//
// The times are written as whole numbers: startTime is rounded down and
// endTime is rounded up, so the name always covers the full clip.
func ClipBaseName(prefix, basename string, startTime, endTime float64) string {
	return fmt.Sprintf(
		"%s_%s_%d_%d",
		prefix,
		basename,
		int(math.Floor(startTime)),
		int(math.Ceil(endTime)),
	)
}

// ClipPaths returns the PNG and WAV paths for a clip inside outputDir. Both
// share the name produced by ClipBaseName and differ only in extension.
//
// Each path is checked with os.Stat, PNG first. If Stat succeeds for either
// one, empty paths and a "file already exists" error are returned. Any Stat
// failure, whatever its cause, is treated as "not present". The check reflects
// the filesystem at call time only.
func ClipPaths(outputDir, prefix, basename string, startTime, endTime float64) (pngPath, wavPath string, err error) {
	stem := ClipBaseName(prefix, basename, startTime, endTime)
	pngCandidate := filepath.Join(outputDir, stem+".png")
	wavCandidate := filepath.Join(outputDir, stem+".wav")

	for _, candidate := range []string{pngCandidate, wavCandidate} {
		if _, statErr := os.Stat(candidate); statErr == nil {
			return "", "", fmt.Errorf("file already exists: %s", candidate)
		}
	}
	return pngCandidate, wavCandidate, nil
}

// WAVBasename returns the bare recording name for a .data file path: a trailing
// ".data" is dropped, the directory part is discarded, and the final
// extension of what remains is removed.
// E.g., "/path/to/file.wav.data" -> "file".
func WAVBasename(dataFilePath string) string {
	name := filepath.Base(strings.TrimSuffix(dataFilePath, ".data"))
	// filepath.Ext always returns a suffix of name, so slicing it off is safe.
	return name[:len(name)-len(filepath.Ext(name))]
}