// A Rust port of vgm2mid by Paul Jensen and Valley Bell.
use crate::config::Config;
use crate::midi_shim::{
	db_to_midi_vol, MIDIShim, MIDI_PAN, MIDI_PAN_CENTER, MIDI_PAN_LEFT, MIDI_PAN_RIGHT,
	MIDI_VOLUME, MIDI_VOLUME_MAX,
};
use crate::strict;
use crate::utils::{hz_to_note, midi_from_note, shift, FactoredState};
use anyhow::{bail, Result};
use midly::num::u4;

/// Bit 7 of a command byte: set for a latch (register-select) byte, clear for a data byte.
pub(crate) const SN76489_LATCH: u8 = 0b1000_0000; // 0x80
/// Mask keeping the latch bit, the two channel bits and the register-type bit of a latch byte.
pub(crate) const SN76489_CHANNEL_SELECT: u8 = SN76489_LATCH | 0b0111_0000; // 0xF0
/// Latch byte selecting the tone period register of tone channel 1.
pub(crate) const SN76489_TONE_1: u8 = SN76489_LATCH; // 0x80
/// Latch byte selecting the tone period register of tone channel 2.
pub(crate) const SN76489_TONE_2: u8 = SN76489_LATCH | 0b0010_0000; // 0xA0
/// Latch byte selecting the tone period register of tone channel 3.
pub(crate) const SN76489_TONE_3: u8 = SN76489_LATCH | 0b0100_0000; // 0xC0
/// Latch byte selecting the noise control register.
pub(crate) const SN76489_NOISE: u8 = SN76489_LATCH | 0b0110_0000; // 0xE0

/// Mask for the two channel bits (bits 5-6) of a latch byte.
/// `(byte & SN76489_CHN_MASK) / 0x20` yields the channel index 0..=3 (3 = noise).
const SN76489_CHN_MASK: u8 = 0x60; // 96

// Noise-register field values. They are not referenced by the converter and are
// kept only to document the register encoding, hence `allow(dead_code)`.

/// Noise feedback bit (bit 2) value 0: periodic noise.
#[allow(dead_code)]
const SN76489_FB_PERIODIC: u8 = 0x00;
/// Noise feedback bit (bit 2) value 1: white noise.
#[allow(dead_code)]
const SN76489_FB_WHITE: u8 = 0x01;

/// Noise clock-source bits (bits 0-1) value 0.
#[allow(dead_code)]
const SN76489_CLOCK_SOURCE_HALF: u8 = 0x00;
/// Noise clock-source bits (bits 0-1) value 1.
#[allow(dead_code)]
const SN76489_CLOCK_SOURCE_FOURTH: u8 = 0x01;
/// Noise clock-source bits (bits 0-1) value 2.
#[allow(dead_code)]
const SN76489_CLOCK_SOURCE_EIGHTH: u8 = 0x02;
/// Noise clock-source bits (bits 0-1) value 3: noise rate follows tone channel 3.
#[allow(dead_code)]
const SN76489_CLOCK_SOURCE_TONE_3: u8 = 0x03;

/// Latch byte selecting the attenuation register of tone channel 1.
const SN76489_ATTENUATOR_1: u8 = SN76489_LATCH | 0b0001_0000; // 0x90
/// Latch byte selecting the attenuation register of tone channel 2.
const SN76489_ATTENUATOR_2: u8 = SN76489_LATCH | 0b0011_0000; // 0xB0
/// Latch byte selecting the attenuation register of tone channel 3.
const SN76489_ATTENUATOR_3: u8 = SN76489_LATCH | 0b0101_0000; // 0xD0
/// Latch byte selecting the attenuation register of the noise channel.
const SN76489_ATTENUATOR_NOISE: u8 = SN76489_LATCH | 0b0111_0000; // 0xF0

/// Lowest 4-bit attenuation (loudest). Unused; kept for documentation.
#[allow(dead_code)]
const SN76489_ATTENUATION_MIN: u8 = 0x0;
/// Highest 4-bit attenuation (silent); also the mask for the attenuation field.
const SN76489_ATTENUATION_MAX: u8 = 0xF;

/// MIDI channel (0-based) of chip 0's first voice; its voices 0..=3 use channels 10..=13.
pub(crate) const MIDI_CHANNEL_SN76489_BASE: u4 = u4::new(10);
// vgm2mid kept a per-voice `SN76489_LastVol(0..=7)` array for note on/off detection;
// this port compares previous and current attenuation values instead.

/// Mutable translation state for up to two SN76489 chips.
///
/// Per-chip arrays have two entries, indexed by chip number. Per-voice arrays have
/// eight slots: 0..=3 for chip 0 and 4..=7 for chip 1. Within each chip, slot 3
/// is the noise voice.
pub(crate) struct SN76489State {
	/// Noise feedback bit (bit 2 of the noise register), per chip.
	feedback: [u8; 2],
	/// Noise clock-source bits (bits 0-1 of the noise register), per chip.
	clock_source: [u8; 2],
	/// Previous 4-bit attenuation of each voice.
	attenuation_1: [u8; 8],
	/// Current 4-bit attenuation of each voice.
	attenuation_2: [u8; 8],
	/// Game Gear stereo state per chip.
	/// Entries 0..=3 hold each voice's last pan value (bit 0 = left, bit 1 = right).
	/// Entry 4 holds the last raw stereo register byte.
	gg_pan: [[u8; 5]; 2],
	/// Chip addressed by the command currently being processed (0 or 1).
	chip_num: u8,
	/// Per-voice counter checked against 10 and 735 for volume-dependent note
	/// retriggering. It is not advanced in this file.
	/// NOTE(review): presumably the caller advances it in samples — confirm.
	pub note_delay: [u32; 8],
	/// Shared per-voice period/frequency/note/volume bookkeeping.
	factored: FactoredState,
	/// Last latch byte (bit 7 set) received per chip; paired with later data bytes.
	last_command: [u8; 2],
}

impl Default for SN76489State {
	fn default() -> Self {
		SN76489State {
			feedback: [0; 2],
			clock_source: [0; 2],
			attenuation_1: [0; 8],
			attenuation_2: [0xFF; 8],
			gg_pan: [[0, 0, 0, 0, 0xFF]; 2],
			chip_num: 0,
			note_delay: [0; 8],
			factored: Default::default(),
			last_command: [0x00; 2],
		}
	}
}

/// Translates SN76489 (or T6W28) register writes into MIDI events.
pub(crate) struct SN76489<'config> {
	/// Translation state; it can be passed back into `SN76489::new` to resume.
	pub(crate) state: SN76489State,
	/// T6W28 mode: chip 0's tone voices are mirrored onto chip 1's voice slots,
	/// and chip 1's tone 3 sets the noise pitch of both halves.
	/// FIXME: this mode should probably become its own type.
	t6w28: bool,
	/// Chip input clock in Hz, used to turn tone periods into frequencies.
	clock: u32,
	/// Conversion options (muted channels, volume-dependent notes, ...).
	config: &'config Config,
}

impl<'config> SN76489<'config> {
	//WARNING: This won't behave well if fnum is small.
	fn hz(fnum: u32, clock: u32) -> f64 {
		if fnum == 0 {
			0.0
		} else {
			(clock as f64 / 32.0) / fnum as f64
		}
	}

	pub(crate) fn new<'c: 'config>(
		t6w28: bool,
		clock: u32,
		config: &'c Config,
		opt_state: Option<SN76489State>,
	) -> Self {
		SN76489 {
			state: opt_state.unwrap_or_default(),
			t6w28,
			clock,
			config,
		}
	}

	pub(crate) fn command_handle(
		&mut self,
		data: u8,
		chip_num: u8,
		midi: &mut MIDIShim,
	) -> Result<()> {
		self.state.chip_num = validate_chip_number(chip_num)?;

		if data & 0x80 != 0 {
			let ret = match data & SN76489_CHANNEL_SELECT {
				// skip and wait for data command
				SN76489_TONE_1 | SN76489_TONE_2 | SN76489_TONE_3 => Ok(()),
				_ => self.command_handle_internal(data, data, midi), // Volume Changes use LSB
			};
			self.state.last_command[self.state.chip_num as usize] = data;
			ret
		} else {
			// do data write
			self.command_handle_internal(
				self.state.last_command[self.state.chip_num as usize],
				data,
				midi,
			)
		}
	}

	fn command_handle_internal(&mut self, msb: u8, lsb: u8, midi: &mut MIDIShim) -> Result<()> {
		if msb & 0x80 == 0x0 {
			return Ok(());
		}

		let mut channel = (msb & SN76489_CHN_MASK) / 0x20;

		match channel.cmp(&0x03) {
			std::cmp::Ordering::Less
				if self.config.sn76489_ch_disabled[channel as usize] =>
			{
				return Ok(())
			},
			std::cmp::Ordering::Equal if self.config.sn76489_noise_disabled => {
				return Ok(())
			},
			std::cmp::Ordering::Greater => bail!("Invalid midi channel"),
			_ => (),
		}

		channel |= self.state.chip_num << 2;
		let mut midi_channel = MIDI_CHANNEL_SN76489_BASE + (channel & 0x03).into();
		if self.state.chip_num == 0x01 {
			midi_channel -= MIDI_CHANNEL_SN76489_BASE
		}

		match msb & SN76489_CHANNEL_SELECT {
			SN76489_TONE_1 | SN76489_TONE_2 | SN76489_TONE_3 => self
				.process_tone_channels(
					msb,
					lsb,
					channel.into(),
					midi_channel,
					midi,
				)?,
			SN76489_NOISE => self.process_noise_channel(lsb, channel.into()),
			SN76489_ATTENUATOR_1 | SN76489_ATTENUATOR_2 | SN76489_ATTENUATOR_3 => self
				.process_tone_attenuation(channel.into(), lsb, midi_channel, midi)?,
			SN76489_ATTENUATOR_NOISE => self.process_noise_attenuation(
				channel.into(),
				lsb,
				midi_channel,
				midi,
			),
			_ => strict!("Invalid channel selected"),
		}
		Ok(())
	}

	fn process_tone_channels(
		&mut self,
		msb: u8,
		lsb: u8,
		channel_ptr: usize,
		midi_channel: u4,
		midi: &mut MIDIShim,
	) -> Result<()> {
		self.state.factored.fnum_1[channel_ptr] = shift(
			&mut self.state.factored.fnum_2[channel_ptr],
			((lsb as u32 & 0x3F) << 4) + (msb & 0x0F) as u32,
		);

		self.state.factored.hz_1[channel_ptr] = shift(
			&mut self.state.factored.hz_2[channel_ptr],
			Self::hz(self.state.factored.fnum_2[channel_ptr], self.clock),
		);

		if !(self.t6w28 && self.state.chip_num > 0x00) {
			let mut temp_note = hz_to_note(self.state.factored.hz_2[channel_ptr]);

			if temp_note >= 0x80.into() {
				//TempNote = 0x7F - self.state.factored.fnum_2[channel]
				temp_note = 0xFF.into();
			}

			self.state.factored.note_1[channel_ptr] =
				shift(&mut self.state.factored.note_2[channel_ptr], temp_note);
			if self.t6w28 {
				self.state.factored.note_1[0x04 + channel_ptr] = shift(
					&mut self.state.factored.note_2[0x04 + channel_ptr],
					temp_note,
				);
			}
		} else {
			//if T6W28_SN76489 & SN76489_NUM > 0x0{
			let mut temp_note = hz_to_note(self.state.factored.hz_2[channel_ptr] / 2.0);
			if (channel_ptr & 0x03) == 0x02 {
				if temp_note < 0xFF.into() {
					temp_note -= 24.0; // - 2 Octaves
					if temp_note >= 0x80.into() {
						temp_note = 0x7F.into();
					}
				}
				self.state.factored.note_2[0x3] = temp_note;
				self.state.factored.note_2[0x7] = temp_note;
			}
			return Ok(());
		}

		if self.state.factored.fnum_2[channel_ptr] == 0 {
			if self.state.factored.note_on_2[channel_ptr] {
				if self.state.factored.midi_note[channel_ptr] < 0xFF {
					midi.do_note_on(
						self.state.factored.note_1[channel_ptr],
						self.state.factored.note_2[channel_ptr],
						midi_channel,
						&mut self.state.factored.midi_note[channel_ptr],
						&mut self.state.factored.midi_wheel[channel_ptr],
						None,
						None,
					)?;
					self.state.factored.midi_note[channel_ptr] = 0xFF.into(); // TODO: Check this
				}

				self.state.factored.note_on_1[channel_ptr] =
					shift(&mut self.state.factored.note_on_2[channel_ptr], false);
			}
		} else if self.state.factored.fnum_2[channel_ptr] > 0
			&& self.state.factored.fnum_2[channel_ptr]
				!= self.state.factored.fnum_1[channel_ptr]
		{
			let dn_ret = midi.do_note_on(
				self.state.factored.note_1[channel_ptr],
				self.state.factored.note_2[channel_ptr],
				midi_channel,
				&mut self.state.factored.midi_note[channel_ptr],
				&mut self.state.factored.midi_wheel[channel_ptr],
				Some(if self.state.factored.note_on_2[channel_ptr] {
					0x0
				} else {
					0xFF
				}),
				None,
			)?;
			self.state.factored.note_on_1[channel_ptr] =
				shift(&mut self.state.factored.note_on_2[channel_ptr], true);
			if dn_ret {
				self.state.note_delay[channel_ptr] = 0x0;
			}
		}
		if self.t6w28 {
			self.tone_t6w28(channel_ptr, midi_channel, midi)?;
		}

		if (channel_ptr & 0x3) == 0x2
			&& self.state.clock_source[self.state.chip_num as usize] == 0x3
		{
			self.command_handle_internal(
				0xE0,
				(self.state.feedback[self.state.chip_num as usize] << 2)
					| self.state.clock_source[self.state.chip_num as usize],
				midi,
			)?;
		}
		Ok(())
	}

	fn tone_t6w28(&mut self, mut channel: usize, mut midi_channel: u4, midi: &mut MIDIShim) -> Result<()> {
		channel += 0x04;
		midi_channel -= MIDI_CHANNEL_SN76489_BASE;
		if self.state.factored.note_2[channel] == 0xFF.into() {
			if self.state.factored.note_on_2[channel] {
				if self.state.factored.midi_note[channel] < 0xFF {
					midi.do_note_on(
						self.state.factored.note_1[channel],
						self.state.factored.note_2[channel],
						midi_channel,
						&mut self.state.factored.midi_note[channel],
						&mut self.state.factored.midi_wheel[channel],
						None,
						None,
					)?;
					//self.state.factored.midi_note[channel] = 0xFF
				}
				self.state.factored.note_on_1[channel] =
					shift(&mut self.state.factored.note_on_2[channel], false);
			}
		} else if self.state.factored.note_2[channel] < 0xFF.into()
			&& self.state.factored.note_2[channel]
				!= self.state.factored.note_1[channel]
		{
			let dn_ret = midi.do_note_on(
				self.state.factored.note_1[channel],
				self.state.factored.note_2[channel],
				midi_channel,
				&mut self.state.factored.midi_note[channel],
				&mut self.state.factored.midi_wheel[channel],
				Some(if self.state.factored.note_on_2[channel] {
					0x0
				} else {
					0xFF
				}),
				None,
			)?;
			self.state.factored.note_on_1[channel] =
				shift(&mut self.state.factored.note_on_2[channel], true);
			if dn_ret {
				self.state.note_delay[channel] = 0x00;
			}
		}
		Ok(())
	}

	fn process_tone_attenuation(
		&mut self,
		channel: usize,
		lsb: u8,
		midi_channel: u4,
		midi: &mut MIDIShim,
	) -> Result<()> {
		self.state.attenuation_1[channel] = shift(
			&mut self.state.attenuation_2[channel],
			lsb & SN76489_ATTENUATION_MAX,
		);

		if self.state.attenuation_2[channel] != self.state.attenuation_1[channel] {
			//self.state.attenuation_2[channel] = (LSB & SN76489_ATTENUATION_MAX) * 8.45 //(127 / 15)
			self.state.attenuation_2[channel] = lsb & SN76489_ATTENUATION_MAX; // I like round values
								   //self.state.factored.midi_volume[0] = MIDI_VOLUME_MAX + 1 - 0x8 - self.state.attenuation_2[channel] * 0x8
								   //if self.state.factored.midi_volume[0] > 0x7F{ self.state.factored.midi_volume[0] = 0x7F
			self.state.factored.midi_volume[0] =
				db_to_midi_vol(Self::vol_to_db(lsb & 0x0F));
			midi.controller_write(
				midi_channel,
				MIDI_VOLUME,
				self.state.factored.midi_volume[0],
			);

			// write Note On/Off
			if self.config.sn76489_voldep_notes >= 1
				&& self.state.factored.note_on_2[channel]
			{
				if self.state.factored.midi_volume[0] == 0 {
					midi.do_note_on(
						self.state.factored.note_1[channel],
						0xFF.into(),
						midi_channel,
						&mut self.state.factored.midi_note[channel],
						&mut self.state.factored.midi_wheel[channel],
						Some(0xFF),
						None,
					)?;
					self.state.note_delay[channel] = 44100;
				//} else if (SN76489_LastVol[channel] = 0x0 & SN76489_NoteDelay[channel] >= 10) | _
				//	(self.config.sn76489_voldep_notes >= 0x2 & SN76489_NoteDelay[channel] >= 735 &
				//	SN76489_LastVol[channel] + 20 < self.state.factored.midi_volume[0]) {
				} else if (self.state.attenuation_1[channel] == 0x0F
					&& self.state.note_delay[channel] >= 10) || (self
					.config
					.sn76489_voldep_notes
					>= 0x2
					&& self.state.note_delay[channel] >= 735
					&& self.state.attenuation_1[channel] - 2
						> self.state.attenuation_2[channel])
				{
					midi.do_note_on(
						self.state.factored.note_1[channel],
						self.state.factored.note_2[channel],
						midi_channel,
						&mut self.state.factored.midi_note[channel],
						&mut self.state.factored.midi_wheel[channel],
						Some(0xFF),
						None,
					)?;
					self.state.note_delay[channel] = 0;
				}
			}
			//SN76489_LastVol[channel] = self.state.factored.midi_volume[0]
		}
		Ok(())
	}

	fn process_noise_attenuation(
		&mut self,
		channel_ptr: usize,
		lsb: u8,
		midi_channel: u4,
		midi: &mut MIDIShim,
	) {
		self.state.attenuation_1[channel_ptr] = shift(
			&mut self.state.attenuation_2[channel_ptr],
			lsb & SN76489_ATTENUATION_MAX,
		);
		self.state.factored.midi_volume[0] = self.calculate_noise_volume(channel_ptr);
		if self.state.attenuation_2[channel_ptr]
			!= self.state.attenuation_1[channel_ptr]
			|| self.state.note_delay[channel_ptr] >= 735
		{
			if self.state.factored.midi_volume[0] > 0 {
				// old Note-Height: 39
				if self.state.factored.note_1[channel_ptr] < 0xFF.into() {
					midi.note_off_write(
						midi_channel,
						midi_from_note(
							self.state.factored.note_1
								[channel_ptr],
						),
						0x00.into(),
					);
				}
				if self.state.factored.note_2[channel_ptr] < 0xFF.into() {
					midi.note_on_write(
						midi_channel,
						midi_from_note(
							self.state.factored.note_2
								[channel_ptr],
						),
						self.state.factored.midi_volume[0],
					);
				}
				self.state.factored.note_1[channel_ptr] =
					self.state.factored.note_2[channel_ptr];
			} else if self.state.factored.midi_volume[0] == 0 && self.state.attenuation_1[channel_ptr] < 0x0F && self.state.factored.note_1[channel_ptr]
						< 0xFF.into() {
   					midi.note_off_write(
   						midi_channel,
   						midi_from_note(
   							self.state.factored.note_1
   								[channel_ptr],
   						),
   						0x00.into(),
   					);
   					self.state.factored.note_1[channel_ptr] = 0xFF.into();
   				}
			self.state.note_delay[channel_ptr] = 0x0;
		}
	}

	fn calculate_noise_volume(&mut self, channel_ptr: usize) -> midly::num::u7 {
		(MIDI_VOLUME_MAX.as_int() + 1 - 8 - (self.state.attenuation_2[channel_ptr] * 8)).into()
	}

	fn process_noise_channel(&mut self, lsb: u8, channel_ptr: usize) {
		self.state.feedback[self.state.chip_num as usize] = (lsb & 0x04) >> 2; //FIXME: Is this supposed to be a boolean?
		self.state.clock_source[self.state.chip_num as usize] = lsb & 0x03;

		// Noise-Frequency
		let noise_frequency = self.calculate_noise_frequency(channel_ptr);

		self.state.factored.fnum_1[channel_ptr] =
			shift(&mut self.state.factored.fnum_2[channel_ptr], noise_frequency);

		if self.state.factored.fnum_2[channel_ptr] == 0 {
			self.state.factored.fnum_2[channel_ptr] = 1
		}

		self.state.factored.hz_1[channel_ptr] = shift(
			&mut self.state.factored.hz_2[channel_ptr],
			Self::hz(self.state.factored.fnum_2[channel_ptr], self.clock),
		);

		//if T6W28_SN76489 & SN76489_NUM > 0x0{ return }

		self.state.factored.note_2[channel_ptr] =
			hz_to_note(self.state.factored.hz_2[channel_ptr]) / 1.5;
		//SN76489_NoteDelay[channel] = 0x0
	}

	fn calculate_noise_frequency(&mut self, channel: usize) -> u32 {
		if self.state.clock_source[self.state.chip_num as usize] == 0x3 {
			self.state.factored.fnum_2[channel - 1] << 1
		} else {
			1 << (5 + self.state.clock_source[self.state.chip_num as usize] as u32)
		}
	}

	pub(crate) fn gg_stereo_handle(
		&mut self,
		register: u8,
		chip_num: u8,
		midi: &mut MIDIShim,
	) -> Result<()> {
		self.state.chip_num = validate_chip_number(chip_num)?;

		let mut pan_val: u8;
		let mut channel: u4;

		for cur_bit in 0x0..=0x3 {
			let channel_mask = 1 << cur_bit;

			pan_val = 0x0;
			if (register & (channel_mask << 4)) != 0 {
				pan_val |= 0x01; // Left Channel On
			}
			if (register & channel_mask) != 0 {
				pan_val |= 0x02; // Right Channel On
			}

			if self.state.gg_pan[self.state.chip_num as usize][cur_bit as usize]
				!= pan_val || self.state.gg_pan[self.state.chip_num as usize][0x4]
				== register
			{
				//if CurBit = 0x0{
				//	CH = CHN_DAC
				//} else {
				channel = MIDI_CHANNEL_SN76489_BASE + cur_bit.into();
				if self.state.chip_num == 0x1 {
					channel -= MIDI_CHANNEL_SN76489_BASE
				}
				//}
				match pan_val {
					0x1 => midi.controller_write(
						channel,
						MIDI_PAN,
						MIDI_PAN_LEFT,
					),
					0x2 => midi.controller_write(
						channel,
						MIDI_PAN,
						MIDI_PAN_RIGHT,
					),
					0x3 => midi.controller_write(
						channel,
						MIDI_PAN,
						MIDI_PAN_CENTER,
					),
					_ => strict!("Illegal pan_val"),
				};
			}
			self.state.gg_pan[self.state.chip_num as usize][cur_bit as usize] = pan_val
		}

		self.state.gg_pan[self.state.chip_num as usize][0x4] = register;
		Ok(())
	}

	pub(crate) fn vol_to_db(tl: u8) -> f64 {
		if tl < 0x0F {
			-f64::from(tl) * 2.0
		} else {
			-400.0 // results in volume 0
		}
	}
}

/// Checks that `num` addresses one of the two supported chips (0 or 1).
///
/// Per-chip state arrays hold exactly two entries. Returns `num` unchanged on
/// success and an "Illegal chip number" error otherwise.
fn validate_chip_number(num: u8) -> Result<u8> {
	if num <= 0x01 {
		Ok(num)
	} else {
		bail!("Illegal chip number")
	}
}