// A Rust port of vgm2mid by Paul Jensen and Valley Bell
use crate::config::Config;
use crate::midi_shim::{db_to_midi_vol, lin_to_db, MIDIShim, MIDI_BANK_SELECT, MIDI_VOLUME};
use crate::sn76489::MIDI_CHANNEL_SN76489_BASE;
use crate::strict;
use crate::utils::shift;
use crate::utils::{hz_to_note, FactoredState};
use crate::vgm2mid::CHN_DAC;
use anyhow::Result;

use midly::MidiMessage::PitchBend;
use midly::TrackEventKind::Midi;

// NES APU register offsets, four per channel (channel = offset / 4).
// NOTE(review): assumed to map to NES address $4000 + offset, as in the VGM NES APU
// write command — confirm against the VGM specification.

// Channel A: pulse 1
const APU_WRA0: u8 = 0x00;
const APU_WRA1: u8 = 0x01;
const APU_WRA2: u8 = 0x02;
const APU_WRA3: u8 = 0x03;

// Channel B: pulse 2
const APU_WRB0: u8 = 0x04;
const APU_WRB1: u8 = 0x05;
const APU_WRB2: u8 = 0x06;
const APU_WRB3: u8 = 0x07;

// Channel C: triangle (offset 0x09 has no handler)
const APU_WRC0: u8 = 0x08;
const APU_WRC2: u8 = 0x0A;
const APU_WRC3: u8 = 0x0B;

// Channel D: noise (offset 0x0D has no handler)
const APU_WRD0: u8 = 0x0C;
const APU_WRD2: u8 = 0x0E;
const APU_WRD3: u8 = 0x0F;

// Channel E: DPCM / DAC
const APU_WRE0: u8 = 0x10;
const APU_WRE1: u8 = 0x11;
const APU_WRE2: u8 = 0x12;
// Not referenced by the register handler; kept so the register map is complete.
#[allow(dead_code)]
const APU_WRE3: u8 = 0x13;

// Channel enable mask (one bit per channel).
const APU_SMASK: u8 = 0x15;
// Not referenced by the register handler; kept so the register map is complete.
#[allow(dead_code)]
const APU_IRQCTRL: u8 = 0x17;

// N2A03 clock: 21 477 270 / 12 = 1 789 772.5
//N2A03 clock / 16
const NES_CLK_BASE: f64 = 111860.78125;

fn hz_nes(fnum: u32) -> f64 {
	NES_CLK_BASE / f64::from(fnum + 1)
}

fn hz_nesnoise(freq_mode: u32) -> f64 {
	let fnum = [
		4, 8, 16, 32, 64, 96, 128, 160, 202, 254, 380, 508, 762, 1016, 2034, 2046,
	][(1 + freq_mode) as usize];
	NES_CLK_BASE / f64::from(fnum + 1)
}

/// Register shadow and derived note state for one NES APU.
///
/// Four-element arrays are indexed by APU channel: 0-1 pulse, 2 triangle, 3 noise.
/// Fields ending in `_1` hold the previous value and `_2` the current one; the
/// register handler shifts `_2` into `_1` before storing a new value.
pub(crate) struct NESAPUState {
	/// Previous / current 4-bit volume (envelope) value.
	envelope_1: [u8; 4],
	envelope_2: [u8; 4],
	/// Previous / current length value from register-3 writes.
	vblen_1: [u8; 4],
	vblen_2: [u8; 4],
	/// Length-halt bit as last written (non-zero = held).
	hold: [u8; 4],
	/// Triangle linear length (`$4008 & 0x7F` write).
	tri_len: u8,
	/// Previous / current duty-cycle index (0-3).
	duty_1: [u8; 4],
	duty_2: [u8; 4],
	/// Previous / current note-gate bit masks, indexed 0-4 (slot 4 is the DPCM/DAC
	/// channel, which the channel-mask handler also visits). A note sounds only when
	/// all three gates are open (`mask & 0x7 == 0x7`):
	///   0x1 = channel enable bit from the channel mask register,
	///   0x2 = length gate, updated on register-3 writes,
	///   0x4 = triangle linear-counter gate, updated on triangle register-0 writes.
	note_en_1: [u8; 5],
	note_en_2: [u8; 5],
	/// 0 after a note (re)starts, 10000 after one is silenced; register-3 writes only
	/// retrigger a note when this is > 10.
	note_delay: [u32; 8],
	/// Frequency, note and MIDI bookkeeping (`crate::utils::FactoredState`).
	factored: FactoredState,
}

impl Default for NESAPUState {
	fn default() -> Self {
		NESAPUState {
			envelope_1: [0; 4],
			// 0xFF is never a valid nibble, so the first write always registers as a change.
			envelope_2: [0xFF; 4],
			vblen_1: [0; 4],
			vblen_2: [0; 4],
			hold: [0; 4],
			tri_len: 0,
			duty_1: [0; 4],
			// 0xFF is never a valid duty index, so the first duty write emits a program change.
			duty_2: [0xFF; 4],
			note_en_1: [0; 5],
			// Gates that no register controls for a channel start open:
			//   pulse/noise have no linear counter -> 0x4 preset;
			//   triangle starts closed (its linear counter `tri_len` is 0);
			//   DPCM/DAC has neither a length nor a linear gate -> 0x2 | 0x4 preset.
			// The previous `[0, 0, 4, 0]` kept pulse/noise permanently gated off and
			// had no slot for channel 4.
			note_en_2: [0x4, 0x4, 0x0, 0x4, 0x6],
			note_delay: [0; 8],
			factored: Default::default(),
		}
	}
}

/// Converts NES APU register writes into MIDI events.
pub(crate) struct NESAPU<'config> {
	/// Conversion settings; per-channel mute switches are read from here.
	config: &'config Config,
	/// Register shadow and derived note state carried between writes.
	state: NESAPUState,
}

impl<'config> NESAPU<'config> {
	pub(crate) fn new<'c: 'config>(config: &'c Config, opt_state: Option<NESAPUState>) -> Self {
		Self {
			state: opt_state.unwrap_or_default(),
			config,
		}
	}

	pub(crate) fn init(&mut self, midi: &mut MIDIShim) {
		midi.controller_write(
			MIDI_CHANNEL_SN76489_BASE + 0x02.into(),
			MIDI_BANK_SELECT,
			0x08.into(),
		);
		midi.program_change_write(MIDI_CHANNEL_SN76489_BASE + 0x02.into(), 0x50.into()); //FIXME: Magic values.
		midi.program_change_write(MIDI_CHANNEL_SN76489_BASE + 0x03.into(), 0x7F.into());
	}

	pub(crate) fn command_handle(
		&mut self,
		register: u8,
		data: u8,
		midi: &mut MIDIShim,
	) -> Result<()> {
		let channel = register / 0x4;
		if channel < 0x3 {
			if self.config.sn76489_ch_disabled[channel as usize] {
				return Ok(());
			}
		} else if channel == 0x3 {
			if self.config.sn76489_noise_disabled {
				return Ok(());
			}
		} else if channel == 0x4 && self.config.ym2612_dac_disabled {
  				return Ok(());
  			}

		let mut midi_channel = if channel == 0x4 {
			CHN_DAC
		} else {
			MIDI_CHANNEL_SN76489_BASE + channel.into()
		};
		match register {
			APU_WRA0 | APU_WRB0 | APU_WRD0 => {
				// Volume, Envelope, Hold, Duty Cycle
				if (register == APU_WRA0 || register == APU_WRB0)
					&& (data & 0xF) > 0x0
				{
					self.state.duty_1[channel as usize] =
						self.state.duty_2[channel as usize];
					self.state.duty_2[channel as usize] = (data & 0xC0) / 0x40;

					if self.state.duty_1[channel as usize]
						!= self.state.duty_2[channel as usize]
					{
						let temp_byte = 0x4F
							+ (!self.state.duty_2[channel as usize]
								& 0x3);
						midi.program_change_write(
							midi_channel,
							temp_byte.into(),
						);
					}
				}

				self.state.envelope_1[channel as usize] =
					self.state.envelope_2[channel as usize];
				// output is 1 * envelope
				self.state.envelope_2[channel as usize] = data & 0xF;
				self.state.hold[channel as usize] = data & 0x20;

				if self.state.envelope_1[channel as usize]
					!= self.state.envelope_2[channel as usize]
				{
					if self.state.envelope_2[channel as usize] == 0x0 {
						midi.note_off_write(
							midi_channel,
							self.state.factored.midi_note
								[channel as usize],
							0x00.into(),
						);
						self.state.factored.midi_note[channel as usize] =
							0xFF.into();
						self.state.note_delay[channel as usize] = 10000;
					}
					self.state.factored.midi_volume[0] =
						db_to_midi_vol(lin_to_db(
							self.state.envelope_2[channel as usize]
								as f64 / 0xF as f64,
						));
					midi.controller_write(
						midi_channel,
						MIDI_VOLUME,
						self.state.factored.midi_volume[0],
					);
				}
			},
			APU_WRC0 => {
				self.state.hold[channel as usize] = data & 0x80;
				let temp_byte = data & 0x7F;
				if temp_byte != self.state.tri_len {
					self.state.tri_len = temp_byte;
					self.state.note_en_1[channel as usize] =
						self.state.note_en_2[channel as usize];
					self.state.note_en_2[channel as usize] =
						(self.state.note_en_2[channel as usize] & !0x4)
							| (self.state.tri_len & 0x4);
					if self.state.note_en_1[channel as usize]
						!= self.state.note_en_2[channel as usize] && (self
						.state
						.note_en_2[channel as usize]
						& 0x3)
						== 0x3
					{
						if self.state.note_en_2[channel as usize] & 0x4 != 0
						{
							midi.do_note_on(
								self.state.factored.note_1
									[channel as usize],
								self.state.factored.note_2
									[channel as usize],
								midi_channel,
								&mut self.state.factored.midi_note
									[channel as usize],
								&mut self.state.factored.midi_wheel
									[channel as usize],
								Some(255),
								None,
							)?;
							self.state.note_delay[channel as usize] = 0;
						} else if self.state.factored.midi_note
							[channel as usize] != 0xFF
						{
							// Note got silenced by setting TriLen = 0
							midi.note_off_write(
								midi_channel,
								self.state.factored.midi_note
									[channel as usize],
								0x00.into(),
							);
							self.state.factored.midi_note
								[channel as usize] = 0xFF.into();
							self.state.note_delay[channel as usize] =
								10000;
						}
					}
				}
			},
			APU_WRA1 | APU_WRB1 => (), // Sweep
			//FIXME: impossible to do in this version of vgm2mid
			APU_WRA2 | APU_WRB2 | APU_WRC2 | APU_WRA3 | APU_WRB3 | APU_WRC3 => {
				if (register & 0x3) == 0x2 {
					self.state.factored.fnum_lsb[channel as usize] = data;
				//Exit Sub
				} else if (register & 0x3) == 0x3 {
					self.state.factored.fnum_msb[channel as usize] = data;
					self.state.vblen_1[channel as usize] =
						self.state.vblen_2[channel as usize];
					if self.state.hold[channel as usize] != 0 {
						self.state.vblen_2[channel as usize] = 0xFF;
					} else {
						self.state.vblen_2[channel as usize] =
							(data & 0xF8) / 0x8;
					}
					self.state.note_en_1[channel as usize] =
						self.state.note_en_2[channel as usize];
					self.state.note_en_2[channel as usize] =
						(self.state.note_en_2[channel as usize] & !0x2)
							| self.state.vblen_2[channel as usize]
								& 0x2;
					if self.state.note_en_1[channel as usize]
						!= self.state.note_en_2[channel as usize] && (self
						.state
						.note_en_2[channel as usize]
						& 0x5)
						== 0x5 && self.state.note_en_2[channel as usize] & 0x2 == 0 && self.state.factored.midi_note
								[channel as usize] != 0xFF {
     							// Note got silenced by setting VBLen = 0
     							midi.note_off_write(
     								midi_channel,
     								self.state.factored.midi_note
     									[channel as usize],
     								0x00.into(),
     							);
     							self.state.factored.midi_note
     								[channel as usize] = 0xFF.into();
     							self.state.note_delay[channel as usize] =
     								10000;
     						}
				}
				self.state.factored.fnum_1[channel as usize] =
					self.state.factored.fnum_2[channel as usize];
				self.state.factored.fnum_2[channel as usize] =
					((self.state.factored.fnum_msb[channel as usize] as u32
						& 0x7) << 8) | self.state.factored.fnum_lsb
						[channel as usize] as u32;

				self.state.factored.hz_1[channel as usize] =
					self.state.factored.hz_2[channel as usize];
				self.state.factored.hz_2[channel as usize] =
					hz_nes(self.state.factored.fnum_2[channel as usize]);
				if channel == 0x2 {
					self.state.factored.hz_2[channel as usize] /= 2.0;
				}

				let mut temp_note =
					hz_to_note(self.state.factored.hz_2[channel as usize]);
				if temp_note >= 0x80.into() {
					//TempNote = 0x7F - self.state.factored.fnum_2[channel as usize]
					temp_note = 0x7F.into();
				}
				self.state.factored.note_1[channel as usize] =
					self.state.factored.note_2[channel as usize];
				self.state.factored.note_2[channel as usize] = temp_note;

				if (self.state.note_en_2[channel as usize] & 0x7) == 0x7 {
					if (register & 0x3) == 0x3
						&& self.state.note_delay[channel as usize] > 10
					{
						// writing to register 3 restarts the notes
						midi.do_note_on(
							self.state.factored.note_1
								[channel as usize],
							self.state.factored.note_2
								[channel as usize],
							midi_channel,
							&mut self.state.factored.midi_note
								[channel as usize],
							&mut self.state.factored.midi_wheel
								[channel as usize],
							Some(255),
							None,
						)?;
						self.state.note_delay[channel as usize] = 0;
					} else if self.state.factored.note_1[channel as usize] != self.state.factored.note_2[channel as usize] && midi.do_note_on(
							self.state.factored.note_1
								[channel as usize],
							self.state.factored.note_2
								[channel as usize],
							midi_channel,
							&mut self.state.factored.midi_note
								[channel as usize],
							&mut self.state.factored.midi_wheel
								[channel as usize],
							None,
							None,
						)? {
     							self.state.note_delay[channel as usize] = 0;
     						}
				}
			},
			APU_WRD2 => {
				// Noise Freq
				self.state.factored.fnum_lsb[channel as usize] = data;
				self.state.factored.fnum_1[channel as usize] = shift(
					&mut self.state.factored.fnum_2[channel as usize],
					(self.state.factored.fnum_lsb[channel as usize] & 0xF)
						.into(),
				);

				self.state.factored.hz_1[channel as usize] =
					self.state.factored.hz_2[channel as usize];
				self.state.factored.hz_2[channel as usize] =
					hz_nesnoise(self.state.factored.fnum_2[channel as usize]);

				let mut temp_note =
					hz_to_note(self.state.factored.hz_2[channel as usize]);
				if temp_note >= 0x80.into() {
					//TempNote = 0x7F - self.state.factored.fnum_2[channel as usize]
					temp_note = 0x7F.into();
				}
				self.state.factored.note_1[channel as usize] =
					self.state.factored.note_2[channel as usize];
				self.state.factored.note_2[channel as usize] = temp_note;

				if (self.state.note_en_2[channel as usize] & 0x7) == 0x7 && self.state.factored.note_1[channel as usize]
						!= self.state.factored.note_2[channel as usize] && midi.do_note_on(
						self.state.factored.note_1[channel as usize],
						self.state.factored.note_2[channel as usize],
						midi_channel,
						&mut self.state.factored.midi_note
							[channel as usize],
						&mut self.state.factored.midi_wheel
							[channel as usize],
						None,
						None,
					)? {
    						self.state.note_delay[channel as usize] = 0;
    					}
			},
			APU_WRD3 => {
				self.state.vblen_1[channel as usize] =
					self.state.vblen_2[channel as usize];
				if self.state.hold[channel as usize] != 0 {
					self.state.vblen_2[channel as usize] = 1;
				} else {
					self.state.vblen_2[channel as usize] = (data & 0xF8) / 0x8;
				}
				self.state.note_en_1[channel as usize] =
					self.state.note_en_2[channel as usize];
				self.state.note_en_2[channel as usize] =
					(self.state.note_en_2[channel as usize] & !0x2)
						| (self.state.vblen_2[channel as usize]) & 0x2;
				if self.state.note_en_1[channel as usize]
					!= self.state.note_en_2[channel as usize] && (self.state.note_en_2[channel as usize] & 0x5) == 0x5 && !self.state.note_en_2[channel as usize] & 0x2 != 0 && self.state.factored.midi_note[channel as usize]
							!= 0xFF {
    						// Note got silenced by setting VBLen = 0
    						midi.note_off_write(
    							midi_channel,
    							self.state.factored.midi_note
    								[channel as usize],
    							0x00.into(),
    						);
    						self.state.factored.midi_note[channel as usize] =
    							0xFF.into();
    						self.state.note_delay[channel as usize] = 10000;
    					}

				if (self.state.note_en_2[channel as usize] & 0x7) == 0x7
					&& self.state.note_delay[channel as usize] > 10
				{
					// writing to register 3 restarts the notes
					midi.do_note_on(
						self.state.factored.note_2[channel as usize],
						self.state.factored.note_2[channel as usize],
						midi_channel,
						&mut self.state.factored.midi_note
							[channel as usize],
						&mut self.state.factored.midi_wheel
							[channel as usize],
						Some(255),
						None,
					)?;
					self.state.note_delay[channel as usize] = 0;
				}
			},
			APU_WRE0 => {
				// IRQ, Looping, Frequency
				self.state.factored.fnum_lsb[channel as usize] = data;
				self.state.factored.fnum_1[channel as usize] =
					self.state.factored.fnum_2[channel as usize];
				self.state.factored.fnum_2[channel as usize] =
					(self.state.factored.fnum_lsb[channel as usize] & 0xF)
						.into();

				if self.state.factored.fnum_1[channel as usize]
					!= self.state.factored.fnum_2[channel as usize]
				{
					self.state.factored.midi_wheel[channel as usize] =
						((self.state.factored.fnum_2[channel as usize]
							<< 10) as u16)
							.into(); //(0x4000 / 0x10)
					midi.event_write(Midi {
						channel: midi_channel,
						message: PitchBend {
							bend: midly::PitchBend(
								self.state.factored.midi_wheel
									[channel as usize],
							),
						},
					});
				}
			},
			APU_WRE1 => {
				midi.note_on_write(
					midi_channel,
					(data & 0x7F).into(),
					0x7F.into(),
				);
				midi.note_off_write(
					midi_channel,
					(data & 0x7F).into(),
					0x00.into(),
				);
			},
			APU_WRE2 => {
				self.state.factored.fnum_msb[channel as usize] = data;

				self.state.factored.fnum_1[channel as usize] =
					self.state.factored.fnum_2[channel as usize];
				self.state.factored.note_2[channel as usize] =
					(self.state.factored.fnum_msb[channel as usize] & 0x7F)
						.into();
				// note is activated by setting the DAC bit in APU_SMASK
			},
			APU_SMASK => {
				for channel in 0..=0x4 {
					let temp_byte = 2 ^ channel;
					midi_channel = if channel == 0x4 {
						CHN_DAC
					} else {
						MIDI_CHANNEL_SN76489_BASE + channel.into()
					};
					self.state.factored.note_on_1[channel as usize] =
						self.state.factored.note_on_2[channel as usize];
					self.state.factored.note_on_2[channel as usize] =
						(data & temp_byte) != 0;

					self.state.note_en_1[channel as usize] =
						self.state.note_en_2[channel as usize];
					self.state.note_en_2[channel as usize] =
						(self.state.note_en_2[channel as usize] & !0x1)
							| self.state.factored.note_on_2
								[channel as usize] as u8;
					if self.state.note_en_1[channel as usize]
						!= self.state.note_en_2[channel as usize]
					{
						if (self.state.note_en_2[channel as usize] & 0x7)
							== 0x7
						{
							midi.do_note_on(
								self.state.factored.note_2
									[channel as usize],
								self.state.factored.note_2
									[channel as usize],
								midi_channel,
								&mut self.state.factored.midi_note
									[channel as usize],
								&mut self.state.factored.midi_wheel
									[channel as usize],
								Some(255),
								None,
							)?;
							self.state.note_delay[channel as usize] = 0;
						} else if self.state.factored.midi_note
							[channel as usize] != 0xFF
						{
							// Note got silenced via Channel Mask
							midi.note_off_write(
								midi_channel,
								self.state.factored.midi_note
									[channel as usize],
								0x00.into(),
							);
							self.state.factored.midi_note
								[channel as usize] = 0xFF.into();
							self.state.note_delay[channel as usize] =
								10000;
						}
					}
				}
			},
			_ => strict!(),
		}
		Ok(())
	}
}