package audio

import (
	"github.com/madelynnblue/go-dsp/fft"
)

// BandpassShiftFilter applies a bandpass filter retaining frequencies between
// lowFreq and highFreq, shifts the retained band down to baseband (0 Hz),
// and downsamples to 2*(highFreq-lowFreq) Hz.
// Returns the processed samples and the new sample rate.
//
// For example, with --bandpass 8000-24000 on 250kHz audio:
//   - Bandpass keeps only 8-24kHz content
//   - Shift down by 8kHz so content is at 0-16kHz
//   - Downsample from 250kHz to 32kHz
//   - Spectrogram shows the 8-24kHz band as if it were 0-16kHz
//
// The filtering is done in the frequency domain: the input is zero-padded to
// a power-of-two length and transformed. The bins from lowFreq through
// highFreq (each frequency converted to a bin index by truncation) are moved
// down to bins 0..(highBin-lowBin), and every other bin is zeroed. The
// spectrum is then inverse-transformed and trimmed back to len(audio)
// samples.
//
// Edge cases:
//   - Empty audio or a non-positive sampleRate returns the input unchanged.
//   - If 2*(highFreq-lowFreq) is not positive (empty or inverted band) or is
//     not below sampleRate, the filtered samples are returned at the
//     original sampleRate without resampling.
func BandpassShiftFilter(audio []float64, sampleRate int, lowFreq, highFreq float64) ([]float64, int) {
	n := len(audio)
	// A non-positive rate would divide by zero below when mapping Hz to bins.
	if n == 0 || sampleRate <= 0 {
		return audio, sampleRate
	}

	paddedLen := nextPowerOf2(n)
	half := paddedLen / 2
	padded := make([]float64, paddedLen)
	copy(padded, audio)

	spectrum := fft.FFTReal(padded)

	lowBin := int(lowFreq * float64(paddedLen) / float64(sampleRate))
	highBin := int(highFreq * float64(paddedLen) / float64(sampleRate))
	bandBins := highBin - lowBin

	// Shift spectrum down by lowBin: bin k gets content from original bin (k + lowBin).
	// Keep only bins within the passband, and mirror them into the upper half
	// (conjugate symmetry) so the inverse transform is real-valued.
	shifted := make([]complex128, paddedLen)
	for k := 0; k <= half; k++ {
		if k <= bandBins {
			// Source bins outside [0, half] (negative lowFreq, or a band
			// running past Nyquist) contribute nothing.
			if srcBin := k + lowBin; srcBin >= 0 && srcBin <= half {
				shifted[k] = spectrum[srcBin]
			}
		}
		// Conjugate symmetry: shifted[N-k] = conj(shifted[k]).
		if k > 0 && k < half {
			shifted[paddedLen-k] = complex(real(shifted[k]), -imag(shifted[k]))
		}
	}

	filtered := fft.IFFT(shifted)

	// Drop the zero-padding tail; only the first n samples correspond to input.
	result := make([]float64, n)
	for i := range result {
		result[i] = real(filtered[i])
	}

	// Downsample to 2 * bandwidth.
	bandwidth := highFreq - lowFreq
	newRate := int(2 * bandwidth)
	// An empty or inverted band gives a non-positive target rate, which is not
	// a valid resampling target; a rate at or above the input rate would not
	// be a downsample. In both cases keep the original rate.
	if newRate <= 0 || newRate >= sampleRate {
		return result, sampleRate
	}
	result = ResampleRate(result, sampleRate, newRate)
	return result, newRate
}

// nextPowerOf2 reports the smallest power of two that is greater than or
// equal to n. Non-positive inputs yield 1. Inputs larger than the biggest
// power of two representable in an int overflow.
func nextPowerOf2(n int) int {
	if n <= 0 {
		return 1
	}
	// Copy the highest set bit of n-1 into every lower bit position, which
	// produces a value of the form 2^k-1; one more is then the next power of
	// two (or n itself when n is already a power of two).
	v := n - 1
	for shift := uint(1); shift <= 32; shift *= 2 {
		v |= v >> shift
	}
	return v + 1
}