package audio

import (
	"math"
	"testing"
)

// TestBandpassShiftFilter_BasicShift feeds a pure 10 kHz tone sampled at
// 48 kHz through an 8-12 kHz pass band and checks two things: the reported
// output rate is twice the bandwidth, and the tone shows up at 2 kHz
// (10000 - 8000) in the shifted output.
func TestBandpassShiftFilter_BasicShift(t *testing.T) {
	const (
		toneHz    = 10000
		lowHz     = 8000
		highHz    = 12000
		shiftedHz = toneHz - lowHz // 2000 Hz once the band is moved to baseband
	)

	sampleRate := 48000
	duration := 0.05
	audio := make([]float64, int(float64(sampleRate)*duration))
	for i := 0; i < len(audio); i++ {
		ts := float64(i) / float64(sampleRate)
		audio[i] = math.Sin(2 * math.Pi * toneHz * ts)
	}

	filtered, newRate := BandpassShiftFilter(audio, sampleRate, lowHz, highHz)

	// Output rate is expected to be twice the 4000 Hz bandwidth, i.e. 8000 Hz.
	expectedRate := 2 * (highHz - lowHz)
	if newRate != expectedRate {
		t.Errorf("newRate = %d, want %d", newRate, expectedRate)
	}

	// Measure the energy in the bin where the shifted tone should land and
	// compare it with the overall energy of the filtered signal.
	binPower := computeBinPowerAtFreq(filtered, newRate, shiftedHz)
	total := signalPower(filtered)
	if total == 0 {
		t.Fatal("filtered signal has no power")
	}
	if share := binPower / total; share < 0.1 {
		t.Errorf("shifted 10kHz→2kHz bin has only %.1f%% of total power, want > 10%%", share*100)
	}
}

// TestBandpassShiftFilter_RemovesOutOfBand verifies that a tone outside the
// pass band does not survive filtering.
//
// The input mixes a 1000 Hz tone (out of band) with a 10000 Hz tone (in band)
// at a 48 kHz sample rate and is filtered with an 8000-12000 Hz pass band.
// The test then compares the power found where the in-band tone should land
// with the power found where a leaked out-of-band tone would land.
func TestBandpassShiftFilter_RemovesOutOfBand(t *testing.T) {
	sampleRate := 48000
	duration := 0.05
	numSamples := int(float64(sampleRate) * duration)

	audio := make([]float64, numSamples)
	for i := range audio {
		ts := float64(i) / float64(sampleRate)
		audio[i] = math.Sin(2*math.Pi*1000*ts) + math.Sin(2*math.Pi*10000*ts)
	}

	filtered, newRate := BandpassShiftFilter(audio, sampleRate, 8000, 12000)

	// The probe frequencies below are only meaningful at the expected output
	// rate of 2 * bandwidth = 8000 Hz, so stop early if that does not hold.
	if newRate != 8000 {
		t.Fatalf("newRate = %d, want 8000", newRate)
	}

	// After the shift the in-band tone sits at 2 kHz (10000 - 8000) and the
	// out-of-band tone would sit at -7 kHz. 7 kHz lies above the 4 kHz Nyquist
	// limit of the 8 kHz output, so probing it maps past the end of the
	// spectrum and computeBinPowerAtFreq reports 0, which would make the
	// leakage check pass unconditionally. Energy at ±7 kHz aliases to 1 kHz at
	// an 8 kHz sample rate, so that is where leakage has to be measured.
	inBandPower := computeBinPowerAtFreq(filtered, newRate, 2000)
	outBandPower := computeBinPowerAtFreq(filtered, newRate, 1000)

	// Without in-band energy the ratio check below would be meaningless.
	if inBandPower <= 0 {
		t.Fatalf("in-band tone missing from filtered output: inBand=%.6f", inBandPower)
	}

	if outBandPower > inBandPower*0.1 {
		t.Errorf("out-of-band leakage: outBand=%.6f, inBand=%.6f", outBandPower, inBandPower)
	}
}

// TestBandpassShiftFilter_EmptyInput checks the degenerate case: an empty
// input slice must yield an empty output, and the returned rate is expected
// to equal the input rate (48000).
func TestBandpassShiftFilter_EmptyInput(t *testing.T) {
	const inputRate = 48000

	out, gotRate := BandpassShiftFilter([]float64{}, inputRate, 1000, 8000)

	if n := len(out); n != 0 {
		t.Errorf("expected empty result for empty input, got %d samples", n)
	}
	if gotRate != inputRate {
		t.Errorf("rate = %d, want 48000", gotRate)
	}
}

// TestBandpassShiftFilter_DownsampleRate filters a short all-zero buffer
// declared as 250 kHz audio through an 8000-24000 Hz pass band. The bandwidth
// is 16000 Hz, so the expected output rate is 32000 Hz, and the output length
// is expected to scale by 32000/250000 (2500 samples → about 320), within a
// tolerance of two samples.
func TestBandpassShiftFilter_DownsampleRate(t *testing.T) {
	silence := make([]float64, 2500) // tiny signal
	out, gotRate := BandpassShiftFilter(silence, 250000, 8000, 24000)

	wantRate := 32000
	if gotRate != wantRate {
		t.Errorf("newRate = %d, want %d", gotRate, wantRate)
	}

	// Scale the input length by the ratio of output to input sample rate.
	wantLen := int(float64(len(silence)) * float64(wantRate) / 250000)
	delta := len(out) - wantLen
	if absInt(delta) > 2 {
		t.Errorf("output samples = %d, want ~%d", len(out), wantLen)
	}
}

// TestBandpassShiftFilter_NarrowBand models a bat-vocalisation setup: a
// 48 kHz tone sampled at 250 kHz is filtered with a 40000-56000 Hz pass band.
// The 16000 Hz bandwidth should give a 32000 Hz output rate, and the tone
// should appear at 8 kHz (48000 - 40000) in the shifted output.
func TestBandpassShiftFilter_NarrowBand(t *testing.T) {
	const toneHz = 48000

	sampleRate := 250000
	audio := make([]float64, 5000)
	for i := 0; i < len(audio); i++ {
		ts := float64(i) / float64(sampleRate)
		audio[i] = math.Sin(2 * math.Pi * toneHz * ts)
	}

	filtered, newRate := BandpassShiftFilter(audio, sampleRate, 40000, 56000)
	if newRate != 32000 {
		t.Errorf("newRate = %d, want 32000", newRate)
	}

	// Energy in the bin where the shifted tone should land, relative to the
	// overall energy of the filtered signal.
	binPower := computeBinPowerAtFreq(filtered, newRate, 8000)
	total := signalPower(filtered)
	if total == 0 {
		t.Fatal("filtered signal has no power")
	}
	if share := binPower / total; share < 0.05 {
		t.Errorf("shifted 48kHz→8kHz bin has only %.1f%% of total power, want > 5%%", share*100)
	}
}

// TestNextPowerOf2 runs nextPowerOf2 over a table of inputs covering zero,
// exact powers of two, and values just above a power of two, and compares
// each result with the expected power of two.
func TestNextPowerOf2(t *testing.T) {
	type testCase struct {
		in       int
		expected int
	}
	cases := []testCase{
		{in: 0, expected: 1},
		{in: 1, expected: 1},
		{in: 2, expected: 2},
		{in: 3, expected: 4},
		{in: 5, expected: 8},
		{in: 100, expected: 128},
		{in: 1024, expected: 1024},
		{in: 1025, expected: 2048},
	}

	for i := 0; i < len(cases); i++ {
		c := cases[i]
		if result := nextPowerOf2(c.in); result != c.expected {
			t.Errorf("nextPowerOf2(%d) = %d, want %d", c.in, result, c.expected)
		}
	}
}

// computeBinPowerAtFreq returns the power in the FFT bin nearest to freqHz.
//
// samples is copied into a zero-padded buffer whose length n is
// nextPowerOf2(len(samples)). That buffer, an n/2+1 element power slice and an
// n element scratch slice are handed to PowerSpectrumFFT, which is expected to
// fill the power slice. The bin index is freqHz*n/sampleRate rounded to the
// nearest integer.
//
// It returns 0 when the frequency does not map onto a bin of the computed
// spectrum: a negative frequency, a frequency beyond the last bin, or a
// non-finite position such as the one produced by a zero sampleRate.
func computeBinPowerAtFreq(samples []float64, sampleRate int, freqHz float64) float64 {
	n := nextPowerOf2(len(samples))
	padded := make([]float64, n)
	copy(padded, samples)

	power := make([]float64, n/2+1)
	scratch := make([]complex128, n)
	PowerSpectrumFFT(padded, power, scratch)

	// Round to the nearest bin instead of truncating, so a frequency that
	// falls just below a bin centre (for example through floating-point
	// error) is not attributed to the bin underneath it.
	pos := math.Round(freqHz * float64(n) / float64(sampleRate))

	// Validate in floating point before converting: the negated form also
	// rejects NaN, and it avoids indexing with a negative bin or converting
	// an out-of-range float to int.
	if !(pos >= 0 && pos < float64(len(power))) {
		return 0
	}
	return power[int(pos)]
}

// signalPower returns the sum of the squared sample values. An empty or nil
// slice yields 0.
func signalPower(samples []float64) float64 {
	var total float64
	for i := 0; i < len(samples); i++ {
		v := samples[i]
		total += v * v
	}
	return total
}

// absInt returns the magnitude of x: x itself when it is zero or positive,
// and its negation otherwise.
func absInt(x int) int {
	if x >= 0 {
		return x
	}
	return -x
}

// BenchmarkBandpassShiftFilter times BandpassShiftFilter on five seconds of
// 250 kHz audio containing a 1 kHz and a 15 kHz tone, using an 8000-24000 Hz
// pass band. Signal generation happens before the timer is reset, so only the
// filter calls are measured.
func BenchmarkBandpassShiftFilter(b *testing.B) {
	const seconds = 5.0

	sampleRate := 250000
	audio := make([]float64, int(float64(sampleRate)*seconds))
	for i := 0; i < len(audio); i++ {
		ts := float64(i) / float64(sampleRate)
		low := math.Sin(2 * math.Pi * 1000 * ts)
		high := math.Sin(2 * math.Pi * 15000 * ts)
		audio[i] = low + high
	}

	b.ResetTimer()
	for remaining := b.N; remaining > 0; remaining-- {
		BandpassShiftFilter(audio, sampleRate, 8000, 24000)
	}
}