#include <stdbool.h>
#include "wait.h"
#include "ps2.h"
#include "ps2_io.h"
#include "debug.h"
/*
 * WAIT(stat, us, err): call wait_<stat>(us) and, if it returns a false value,
 * store `err` in ps2_error and jump to the ERROR label of the enclosing
 * function.
 *
 *   stat - suffix pasted onto `wait_` to pick the helper (this file uses
 *          clock_lo, clock_hi, data_lo and data_hi)
 *   us   - argument forwarded to the helper
 *          (NOTE(review): presumably a timeout in microseconds - confirm
 *          against the wait_* declarations)
 *   err  - value written to ps2_error when the helper fails
 *
 * Any function using this macro must define an `ERROR:` label. The
 * do { } while (0) wrapper lets the macro be used like a single statement.
 */
#define WAIT(stat, us, err)     \
    do {                        \
        if (!wait_##stat(us)) { \
            ps2_error = err;    \
            goto ERROR;         \
        }                       \
    } while (0)
/*
 * Status of the most recent transfer. ps2_host_send() and ps2_host_recv()
 * reset it to PS2_ERR_NONE on entry; WAIT() overwrites it with the code of
 * the step that failed, and ps2_host_recv() sets PS2_ERR_PARITY on a parity
 * mismatch.
 */
uint8_t ps2_error = PS2_ERR_NONE;
/**
 * Bring up the PS/2 host interface.
 *
 * Initialises the clock and data lines, blocks for 2.5 seconds, and then
 * inhibits the bus.
 *
 * NOTE(review): the delay presumably gives the attached device time to finish
 * its power-on reset and self-test - confirm against the PS/2 documentation.
 */
void ps2_host_init(void) {
    enum { STARTUP_DELAY_MS = 2500 };

    clock_init();
    data_init();
    wait_ms(STARTUP_DELAY_MS);
    inhibit();
}
/*
 * Transmit one byte to the device by driving the clock and data lines
 * directly, then collect the device's reply.
 *
 * ps2_error is reset to PS2_ERR_NONE on entry. Every WAIT() below stores its
 * own numeric code (10, 2..9) in ps2_error and jumps to ERROR when the
 * corresponding wait_*() helper fails, so the code identifies the step that
 * did not complete.
 *
 * data: byte to transmit; bits are sent least-significant first.
 *
 * Returns the result of ps2_host_recv_response() when the whole frame was
 * sent, or 0 when any WAIT() failed. The bus is inhibited on both paths.
 */
uint8_t ps2_host_send(uint8_t data) {
    bool parity = true;
    ps2_error   = PS2_ERR_NONE;
    /* Inhibit the bus before starting (NOTE(review): presumably to abort any transfer in progress - confirm). */
    inhibit();
    wait_us(100);  
    /* Pull data low and release clock, then wait for the device to drive clock low (NOTE(review): presumably request-to-send / start bit). */
    data_lo();
    clock_hi();
    WAIT(clock_lo, 10000, 10);  
    /* Eight data bits, LSB first; each bit is set up 15 us after clock went low, then one clock high/low cycle is awaited. */
    for (uint8_t i = 0; i < 8; i++) {
        wait_us(15);
        if (data & (1 << i)) {
            /* parity flips once per 1 bit, so it stays true only for an even count of ones */
            parity = !parity;
            data_hi();
        } else {
            data_lo();
        }
        WAIT(clock_hi, 50, 2);
        WAIT(clock_lo, 50, 3);
    }
    /* Parity bit: driven high when the data held an even number of ones, i.e. odd parity over data + parity. */
    wait_us(15);
    if (parity) {
        data_hi();
    } else {
        data_lo();
    }
    WAIT(clock_hi, 50, 4);
    WAIT(clock_lo, 50, 5);
    /* Release the data line high (NOTE(review): presumably the stop bit). */
    wait_us(15);
    data_hi();
    /* Wait for the device to pull data low and then clock low (NOTE(review): presumably the device's acknowledge). */
    WAIT(data_lo, 50, 6);
    WAIT(clock_lo, 50, 7);
    /* Wait until both lines have gone high again before finishing. */
    WAIT(clock_hi, 50, 8);
    WAIT(data_hi, 50, 9);
    inhibit();
    return ps2_host_recv_response();
ERROR:
    /* A WAIT() failed: ps2_error already holds its code; inhibit the bus and report 0. */
    inhibit();
    return 0;
}
/**
 * Read a byte from the device, retrying while reception keeps failing.
 *
 * ps2_host_recv() is always called once. It is called again as long as the
 * retry budget (250 retries after the first attempt) is not used up and
 * ps2_error is non-zero after the latest attempt.
 *
 * @return the value returned by the last ps2_host_recv() call; ps2_error
 *         tells whether that final attempt succeeded.
 */
uint8_t ps2_host_recv_response(void) {
    uint8_t retries_left = 250;
    uint8_t received;

    for (;;) {
        received = ps2_host_recv();

        /* The budget is consumed on every pass, even when the loop ends. */
        if (retries_left-- == 0) {
            break;
        }
        if (!ps2_error) {
            break;
        }
    }
    return received;
}
/*
 * Receive one byte from the device by sampling the data line around the
 * clock transitions the device generates.
 *
 * ps2_error is reset to PS2_ERR_NONE on entry. Each WAIT() stores its own
 * numeric code (1..10) in ps2_error and jumps to ERROR when its wait_*()
 * helper fails; a parity mismatch stores PS2_ERR_PARITY instead.
 *
 * Returns the received byte on success, or 0 on any failure. Because 0 can
 * also be a received value, callers have to inspect ps2_error to tell the two
 * apart. The bus is inhibited before returning on both paths.
 */
uint8_t ps2_host_recv(void) {
    uint8_t data   = 0;
    bool    parity = true;
    ps2_error      = PS2_ERR_NONE;
    /* Put the bus in the idle state via idle() (NOTE(review): presumably releases both lines so the device may transmit). */
    idle();
    /* Expect clock low with data low, then clock high again (NOTE(review): presumably the start bit). */
    WAIT(clock_lo, 100, 1);      WAIT(data_lo, 1, 2);
    WAIT(clock_hi, 50, 3);
    /* Eight data bits, LSB first: sample data once clock has gone low, then wait for clock to return high. */
    for (uint8_t i = 0; i < 8; i++) {
        WAIT(clock_lo, 50, 4);
        if (data_in()) {
            /* parity flips once per 1 bit, so it stays true only for an even count of ones */
            parity = !parity;
            data |= (1 << i);
        }
        WAIT(clock_hi, 50, 5);
    }
    /* Parity bit: must equal the running parity flag (odd parity over data + parity), otherwise fail with PS2_ERR_PARITY. */
    WAIT(clock_lo, 50, 6);
    if (data_in() != parity) {
        ps2_error = PS2_ERR_PARITY;
        goto ERROR;
    }
    WAIT(clock_hi, 50, 7);
    /* Final clock cycle with data expected high (NOTE(review): presumably the stop bit). */
    WAIT(clock_lo, 50, 8);
    WAIT(data_hi, 1, 9);
    WAIT(clock_hi, 50, 10);
    inhibit();
    return data;
ERROR:
    /*
     * Only codes greater than PS2_ERR_STARTBIT3 are printed.
     * NOTE(review): presumably the lower codes are the ordinary "no data
     * pending" case and are kept quiet on purpose - confirm.
     */
    if (ps2_error > PS2_ERR_STARTBIT3) {
        xprintf("x%02X\n", ps2_error);
    }
    inhibit();
    return 0;
}
/**
 * Send the byte 0xED to the device, followed by the caller's LED byte.
 *
 * NOTE(review): 0xED is presumably the PS/2 keyboard "set status indicators"
 * command and `led` its bitmask argument - confirm against the keyboard
 * command reference.
 *
 * The values returned by ps2_host_send() are deliberately discarded; the
 * second byte is sent regardless of how the first transfer went.
 *
 * @param led byte transmitted right after the 0xED command byte
 */
void ps2_host_set_led(uint8_t led) {
    enum { SET_LED_COMMAND = 0xED };

    (void)ps2_host_send(SET_LED_COMMAND);
    (void)ps2_host_send(led);
}