use chisel_tuto::get_handle;
use cxxrtl::{AsBool, CxxrtlHandle, CxxrtlSignal};
use std::env;

/// Test-side wrapper around the simulated `Mux4` design: the simulation
/// handle plus one signal accessor per port looked up in `Mux4::new`.
///
/// NOTE(review): the const parameter of `CxxrtlSignal` is presumably the
/// signal width in bits — confirm against the `cxxrtl` crate documentation.
struct Mux4 {
    // Simulation handle; `step` forwards to it.
    pub handle: CxxrtlHandle,
    // Select port (`io_sel`); the test drives it with `s1 << 1 | s0`.
    pub io_sel: CxxrtlSignal<2>,
    // Data input ports `io_in0` .. `io_in3`; the test drives them with booleans.
    pub io_in0: CxxrtlSignal<1>,
    pub io_in1: CxxrtlSignal<1>,
    pub io_in2: CxxrtlSignal<1>,
    pub io_in3: CxxrtlSignal<1>,
    // Output port (`io_out`); the test reads it back as a `u8`.
    pub io_out: CxxrtlSignal<1>,
}

impl Mux4 {
    /// Builds the wrapper for the `Mux4.so` library located in the build's
    /// `OUT_DIR` (resolved at compile time) and looks up every port by name.
    ///
    /// # Panics
    ///
    /// Panics, naming the offending signal, if the handle returned by
    /// `get_handle` does not expose one of the six expected ports.
    fn new() -> Self {
        let lib = concat!(env!("OUT_DIR"), "/Mux4.so");
        let handle = get_handle(lib);
        Self {
            io_sel: handle
                .get("io_sel")
                .expect("Mux4 model does not expose signal `io_sel`")
                .signal(),
            io_in0: handle
                .get("io_in0")
                .expect("Mux4 model does not expose signal `io_in0`")
                .signal(),
            io_in1: handle
                .get("io_in1")
                .expect("Mux4 model does not expose signal `io_in1`")
                .signal(),
            io_in2: handle
                .get("io_in2")
                .expect("Mux4 model does not expose signal `io_in2`")
                .signal(),
            io_in3: handle
                .get("io_in3")
                .expect("Mux4 model does not expose signal `io_in3`")
                .signal(),
            io_out: handle
                .get("io_out")
                .expect("Mux4 model does not expose signal `io_out`")
                .signal(),
            // Moved last: every lookup above only borrows the handle.
            handle,
        }
    }

    /// Forwards to the underlying handle's `step`.
    ///
    /// NOTE(review): presumably this evaluates the design so that `io_out`
    /// reflects the most recently written inputs — confirm against `cxxrtl`.
    fn step(&mut self) {
        self.handle.step()
    }
}

/// Exhaustively drives the mux with every combination of the two select bits
/// and the four data inputs, and checks after each step that `io_out` equals
/// the input chosen by the select value (`0 -> in0`, `1 -> in1`, `2 -> in2`,
/// `3 -> in3`).
#[test]
fn test_mux4() {
    let mut dut = Mux4::new();
    // Six one-bit variables give 2^6 = 64 combinations. Reading the counter's
    // bits from most to least significant as s0, s1, i0, i1, i2, i3 visits the
    // combinations with s0 varying slowest and i3 varying fastest.
    for combo in 0..64u8 {
        let bit = |pos: u8| (combo >> pos) & 1;
        let (s0, s1) = (bit(5), bit(4));
        let inputs = [bit(3), bit(2), bit(1), bit(0)];
        let sel = s1 << 1 | s0;

        dut.io_sel.set::<u8>(sel);
        dut.io_in0.set::<bool>(inputs[0].as_bool());
        dut.io_in1.set::<bool>(inputs[1].as_bool());
        dut.io_in2.set::<bool>(inputs[2].as_bool());
        dut.io_in3.set::<bool>(inputs[3].as_bool());
        dut.step();

        // `sel` is always in 0..=3, so it indexes the selected input directly.
        assert_eq!(
            dut.io_out.get::<u8>(),
            inputs[usize::from(sel)],
            "io_out mismatch for io_sel={} with [in0, in1, in2, in3]={:?}",
            sel,
            inputs
        );
    }
}