const std = @import("std");
const microzig = @import("microzig");
const rp2xxx = microzig.hal;

const dcc = @import("dcc.zig");
const main = @import("main.zig");

const log = std.log.scoped(.motor);

/// Motor driver backend selector. The functions in this file take a value of
/// this union as a `comptime` parameter and dispatch on its active variant.
pub const MotorDriver = union(enum) {
    /// Backend driven through two direction pins and one PWM channel.
    l293: L293,
    /// Backend whose struct is currently empty (no fields, no methods).
    bdr6133: BDR6133,
};

/// Motor driver backend made of two direction GPIO pins and one PWM channel.
const L293 = struct {
    const Self = @This();

    /// Direction pin written high by `forward` and low by `backward`.
    fwd: rp2xxx.gpio.Pin,
    /// Direction pin written high by `backward` and low by `forward`.
    bwd: rp2xxx.gpio.Pin,
    /// PWM channel whose level is set by `forward`, `backward` and `freerun`.
    speed: rp2xxx.pwm.Pwm,
    // NOTE(review): presumably the start-PWM histories for the forward and
    // backward directions; no method of this struct reads or writes them.
    fspwm: StartPWM = .{},
    bspwm: StartPWM = .{},

    /// Writes both direction pins: `fwd` first, then `bwd`.
    fn setInputs(self: Self, fwd_level: u1, bwd_level: u1) void {
        self.fwd.put(fwd_level);
        self.bwd.put(bwd_level);
    }

    /// Drives both direction pins high. The PWM level is left untouched.
    fn stop(self: Self) void {
        self.setInputs(1, 1);
        log.debug("Motor stopped", .{});
    }

    /// Sets the PWM level to 0. The direction pins are left untouched.
    fn freerun(self: Self) void {
        self.speed.set_level(0);
        log.debug("Motor in freerun mode", .{});
    }

    /// Selects the forward direction (`fwd` high, `bwd` low) and then applies
    /// `level` to the PWM channel.
    fn forward(self: Self, level: u16) void {
        self.setInputs(1, 0);
        self.speed.set_level(level);
        log.debug("Motor is moving forward: {d}", .{level});
    }

    /// Selects the backward direction (`fwd` low, `bwd` high) and then applies
    /// `level` to the PWM channel.
    fn backward(self: Self, level: u16) void {
        self.setInputs(0, 1);
        self.speed.set_level(level);
        log.debug("Motor is moving backward: {d}", .{level});
    }
};

/// Placeholder for the BDR6133 backend: it has no fields and no methods yet.
/// NOTE(review): other functions in this file call `stop`, `freerun`, `forward`
/// and `backward` on `md.bdr6133`; none of them exist here, so instantiating
/// those functions with a `.bdr6133` driver presumably fails to compile — confirm.
const BDR6133 = struct {};

/// Stops the motor, waits for it to settle, then averages a fixed number of
/// free-running ADC conversions taken from `pin`.
///
/// `md` is the comptime-selected motor driver to stop before sampling.
/// `pin` is the ADC input to sample.
///
/// Returns the integer mean of the collected samples. Any error reported by
/// the ADC FIFO is propagated. The ADC is enabled for the duration of the call
/// and disabled again on every exit path.
pub fn measureEMCOffset(comptime md: MotorDriver, pin: rp2xxx.adc.Input) !u12 {
    const sample_count = 50;
    const settle_ms = 500;

    rp2xxx.adc.set_enabled(true);
    defer rp2xxx.adc.set_enabled(false);

    switch (md) {
        .l293 => md.l293.stop(),
        .bdr6133 => md.bdr6133.stop(),
    }
    // Wait after stopping the driver before taking any reading.
    rp2xxx.time.sleep_ms(settle_ms);

    rp2xxx.adc.select_input(pin);
    rp2xxx.adc.start(.free_running);

    var total: usize = 0;
    var taken: usize = 0;
    while (taken < sample_count) : (taken += 1) {
        // Busy-wait for the next conversion; the empty asm statement is a
        // compiler memory barrier so the FIFO status is re-read every pass.
        while (rp2xxx.adc.fifo.is_empty()) {
            asm volatile ("" ::: "memory");
        }
        const sample = try rp2xxx.adc.fifo.pop();
        total += sample;
    }

    const offset: u12 = @truncate(total / sample_count);
    log.debug("EMC offset of {} is: {d}", .{ pin, offset });
    return offset;
}

/// Measures the motor's back-EMF as a trimmed mean of ADC samples.
///
/// The driver is put into freerun, the configured delay elapses, and
/// `iterations` free-running ADC conversions are collected. The motor is then
/// driven again in `dir` at `level` before the samples are sorted and the
/// `low_cutoff` smallest and `high_cutoff` largest values are discarded.
///
/// This function does not select an ADC input; the caller must have done so.
///
/// `md` is the comptime-selected motor driver.
/// `dir` and `level` are the direction and PWM level restored after sampling.
///
/// Returns the integer mean of the samples that remain after trimming.
/// Returns `error.InvalidEmfConfig` when the configuration leaves no sample to
/// average (`low_cutoff + high_cutoff >= iterations`, which includes
/// `iterations == 0`) or asks for more samples than the buffer holds. This is
/// checked before the motor or the ADC is touched. Errors from the ADC FIFO
/// are propagated; in that case the motor is left in freerun.
fn emf(comptime md: MotorDriver, dir: dcc.Direction, level: u16) !u12 {
    const emf_config = main.cv.manufacturer.emf;

    var results: [std.math.maxInt(u8)]u12 = undefined;

    // Widen to usize so the checks and the arithmetic below cannot overflow.
    const iterations: usize = emf_config.iterations;
    const low_cutoff: usize = emf_config.low_cutoff;
    const high_cutoff: usize = emf_config.high_cutoff;

    // Reject configurations that would underflow the slice bounds or divide by
    // zero further down, instead of failing after the motor has been cut.
    if (iterations > results.len or low_cutoff + high_cutoff >= iterations) {
        return error.InvalidEmfConfig;
    }
    const kept = iterations - low_cutoff - high_cutoff;

    defer rp2xxx.adc.set_enabled(false);
    rp2xxx.adc.set_enabled(true);

    switch (md) {
        .l293 => md.l293.freerun(),
        .bdr6133 => md.bdr6133.freerun(),
    }
    rp2xxx.time.sleep_us(emf_config.delay_us);

    rp2xxx.adc.start(.free_running);
    for (results[0..iterations]) |*slot| {
        // Busy-wait for the next conversion; the empty asm statement is a
        // compiler memory barrier so the FIFO status is re-read every pass.
        while (rp2xxx.adc.fifo.is_empty()) {
            asm volatile ("" ::: "memory");
        }
        slot.* = try rp2xxx.adc.fifo.pop();
        // TODO: check if we should sort right here
    }

    // Resume driving before the sort and averaging work.
    switch (md) {
        .l293 => {
            switch (dir) {
                .forward => md.l293.forward(level),
                .backward => md.l293.backward(level),
            }
        },
        .bdr6133 => {
            switch (dir) {
                .forward => md.bdr6133.forward(level),
                .backward => md.bdr6133.backward(level),
            }
        },
    }

    std.sort.pdq(u12, results[0..iterations], {}, less);

    // Average the samples left after dropping both tails.
    var sum: usize = 0;
    for (results[low_cutoff .. iterations - high_cutoff]) |sample| {
        sum += sample;
    }
    const aemf = sum / kept;

    return @truncate(aemf);
}

/// Fixed-size ring buffer holding the last `START_PWM_SIZE` PWM levels passed
/// to `save`, with an `average` over all slots.
const StartPWM = struct {
    const START_PWM_SIZE = 10;
    // TODO: Initial values could be 2/3 of pwm_wrap
    history: [START_PWM_SIZE]u16 = [_]u16{0} ** START_PWM_SIZE, // PWM wrap is u16
    /// Index of the slot the next `save` overwrites; kept below `START_PWM_SIZE`.
    next: u8 = 0, // CV is u8

    /// Stores `val` in the current slot and advances to the next one, wrapping
    /// back to slot 0 after the last slot so the oldest entry is overwritten.
    fn save(self: *@This(), val: u16) void {
        // Reduce first so an out-of-range `next` set from outside cannot index
        // past the buffer either.
        const slot = self.next % START_PWM_SIZE;
        self.history[slot] = val;
        self.next = (slot + 1) % START_PWM_SIZE;
    }

    /// Returns the integer mean of all `START_PWM_SIZE` slots. Slots that were
    /// never written still hold their initial value and count toward the mean.
    fn average(self: @This()) u16 {
        var sum: usize = 0;
        for (self.history) |level| {
            sum += level;
        }

        return @truncate(sum / START_PWM_SIZE);
    }
};

/// Test routine: ramps the motor forward from 65% of `pwm_period` up to (but
/// not including) `pwm_period`, and at each step runs for two seconds and then
/// logs the averaged back-EMF read from `main.pins.m1emf`.
///
/// `md` is the comptime-selected motor driver.
/// `pwm_period` is the upper bound of the ramp. The step is one fiftieth of
/// it, clamped to at least 1 so periods below 50 still advance and terminate.
///
/// Errors from the back-EMF measurement are propagated; the motor keeps its
/// last commanded level when that happens.
pub fn tester(comptime md: MotorDriver, pwm_period: u16) !void {
    const period: usize = pwm_period;
    // An integer step of 0 (period < 50) would never advance `level`.
    const step: usize = @max(1, period / 50);

    var level: usize = period * 65 / 100;
    while (level < period) : (level += step) {
        switch (md) {
            .l293 => md.l293.forward(@intCast(level)),
            .bdr6133 => md.bdr6133.forward(@intCast(level)),
        }
        rp2xxx.time.sleep_ms(2000);

        { // Measure Back-EMF
            rp2xxx.adc.select_input(main.pins.m1emf);

            const m1emf = try emf(md, .forward, @intCast(level));

            log.debug("Motor1 average EMF: {d}", .{m1emf});
        }
    }
}

/// Ascending-order comparator for `u12` values, in the shape expected by the
/// `std.sort` functions: returns true only when `a` is strictly smaller than `b`.
fn less(_: void, a: u12, b: u12) bool {
    return std.math.order(a, b) == .lt;
}