//! Analyze the dependencies of Cargo projects.
use crate::annotations::{self_profile, timings, unit_graph};

use cargo_metadata::{Metadata, Package, PackageId};
use petgraph::data::{Element, FromElements};
use petgraph::graph::NodeIndex;
use petgraph::matrix_graph::Zero;
use petgraph::visit::EdgeRef;
use petgraph::{Direction, Graph};

/// How a value computed by `AnnotationGraph::variable` is reported.
#[derive(Clone, Copy, Debug)]
pub enum Measurement {
    /// The exact value divided by the number of units that depend on the node
    /// (left unchanged when nothing depends on it).
    Relative,
    /// The value as measured, without any scaling.
    Exact,
}

/// Quantity that `AnnotationGraph::variable` can compute for a node.
#[derive(Clone, Copy, Debug)]
pub enum Variable {
    /// Sum of the durations of the node's own timing messages.
    UnitDuration,
    /// The node's own duration plus the durations of its direct dependencies.
    TotalDuration,
}

/// One compilation unit from cargo's unit graph, together with the metadata of
/// its package and the timing messages attributed to it.
#[derive(Clone, Debug)]
pub struct Node<'graph> {
    /// Position of the unit in the unit graph's `units` list; doubles as the
    /// node's index in the annotation graph.
    index: usize,
    /// Cargo metadata of the package that owns this unit.
    metadata: Package,
    /// The borrowed unit-graph entry this node describes.
    unit: &'graph unit_graph::Unit,
    /// Timing messages looked up by this unit's package id.
    timings: Vec<timings::Message>,
}

impl Node<'_> {
    /// Returns the index of this node inside the annotation graph.
    pub fn index(&self) -> NodeIndex {
        let position = self.index;
        NodeIndex::new(position)
    }

    /// Returns the package id recorded on the underlying unit.
    pub fn id(&self) -> &PackageId {
        let unit = self.unit;
        &unit.pkg_id
    }

    /// Returns the name of the package this unit belongs to.
    pub fn name(&self) -> &str {
        let package = &self.metadata;
        package.name.as_str()
    }

    /// Returns the timing messages attached to this node.
    pub fn timings(&self) -> &Vec<timings::Message> {
        &self.timings
    }
}

/// Dependency graph of cargo compilation units, annotated with package metadata
/// and timing data.
///
/// An edge `a -> b` means unit `a` lists unit `b` among its dependencies.
#[derive(Debug)]
pub struct AnnotationGraph<'graph> {
    // Node indices match the positions of the units in `unit_graph.units`.
    graph: Graph<Node<'graph>, ()>,
    // Timing output as left after per-package messages were moved into the nodes.
    timings: timings::Output,
    // Self-profile data passed to `AnnotationGraph::new` (not read by the visible methods).
    self_profile: self_profile::ProfileCollection,
    // The unit graph that the nodes borrow their units from.
    unit_graph: &'graph unit_graph::UnitGraph,
}

impl<'graph> AnnotationGraph<'graph> {
    /// Builds the annotated graph from cargo's unit graph.
    ///
    /// One node is created per entry of `unit_graph.units`, in order, so a node's
    /// [`NodeIndex`] equals the unit's position in that list. Every entry of a unit's
    /// `dependencies` becomes an edge from the unit to that dependency.
    ///
    /// Timing messages are moved out of `timings.repr` (keyed by package id) into the
    /// first node created for each package. Later units of the same package, and units
    /// with no recorded timings, get an empty list. Whatever remains in `timings` is
    /// stored on the graph.
    ///
    /// # Panics
    ///
    /// May panic if a unit's package id is not present in `metadata` (it is indexed
    /// directly).
    pub fn new(
        metadata: Metadata,
        mut timings: timings::Output,
        self_profile: self_profile::ProfileCollection,
        unit_graph: &'graph unit_graph::UnitGraph,
    ) -> Self {
        let nodes = unit_graph
            .units
            .iter()
            .enumerate()
            .map(|(index, unit)| Element::Node {
                weight: Node {
                    index,
                    metadata: metadata[&unit.pkg_id].clone(),
                    unit,
                    // Cargo's unit graph can list several units for one package (e.g. a
                    // build script and a library), and a unit may have no timing entry,
                    // so a missing key must not abort graph construction.
                    timings: timings.repr.remove(&unit.pkg_id).unwrap_or_default(),
                },
            });
        let edges = unit_graph
            .units
            .iter()
            .enumerate()
            .flat_map(|(index, unit)| {
                unit.dependencies.iter().map(move |dep| Element::Edge {
                    source: index,
                    target: dep.index,
                    weight: (),
                })
            });

        // All nodes precede all edges, so every edge endpoint exists when it is added.
        let graph = Graph::from_elements(nodes.chain(edges));

        Self {
            graph,
            timings,
            self_profile,
            unit_graph,
        }
    }

    /// Iterates over the nodes of the unit graph's root units, in `roots` order.
    pub fn roots(&'graph self) -> impl Iterator<Item = &'graph Node> {
        self.unit_graph
            .roots
            .iter()
            .map(|&root| &self.graph[NodeIndex::new(root)])
    }

    /// Iterates over every node (one per unit) in index order.
    pub fn packages(&'graph self) -> impl Iterator<Item = &'graph Node> {
        self.graph.node_indices().map(|index| &self.graph[index])
    }

    /// Iterates over all edges as `(dependent, dependency)` node pairs, in edge index
    /// order.
    pub fn edges(&'graph self) -> impl Iterator<Item = (&'graph Node, &'graph Node)> {
        self.graph
            .edge_references()
            .map(|edge| (&self.graph[edge.source()], &self.graph[edge.target()]))
    }

    /// Iterates over the nodes adjacent to `node` in `direction`.
    ///
    /// `Direction::Outgoing` yields the units `node` depends on;
    /// `Direction::Incoming` yields the units that depend on `node`.
    pub fn node_edges(
        &'graph self,
        node: &Node,
        direction: Direction,
    ) -> impl Iterator<Item = &'graph Node> {
        self.graph
            .neighbors_directed(node.index(), direction)
            .map(|index| &self.graph[index])
    }

    /// Computes `variable` for `node`, reported according to `measurement`.
    ///
    /// * `Variable::UnitDuration`: sum of the `duration`s of the node's own timing
    ///   messages.
    /// * `Variable::TotalDuration`: the node's own duration plus the durations of its
    ///   *direct* dependencies (transitive dependencies are not included).
    ///
    /// With `Measurement::Relative` the exact value is divided by the number of
    /// incoming edges (units depending on `node`). A node without dependents keeps
    /// its exact value.
    pub fn variable(&self, node: &Node, variable: Variable, measurement: Measurement) -> f64 {
        let exact_measurement = match variable {
            Variable::UnitDuration => node.timings.iter().map(|msg| msg.duration).sum::<f64>(),
            Variable::TotalDuration => {
                let dependencies_duration: f64 = self
                    .graph
                    .neighbors_directed(node.index(), Direction::Outgoing)
                    .flat_map(|dep| &self.graph[dep].timings)
                    .map(|msg| msg.duration)
                    .sum();

                // Request the exact own duration: relative scaling is applied once to
                // the whole total below. Passing `measurement` through would divide
                // this part twice.
                let own_duration =
                    self.variable(node, Variable::UnitDuration, Measurement::Exact);
                dependencies_duration + own_duration
            }
        };

        match measurement {
            Measurement::Exact => exact_measurement,
            Measurement::Relative => {
                let dependents = self
                    .graph
                    .edges_directed(node.index(), Direction::Incoming)
                    .count() as f64;

                // Avoid dividing by zero for units nothing depends on (e.g. roots).
                if dependents.is_zero() {
                    exact_measurement
                } else {
                    exact_measurement / dependents
                }
            }
        }
    }
}